• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/client/client_library.h"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/execution_options_util.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/platform_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 // Name of the interpreter backend.
40 constexpr char kInterpreter[] = "interpreter";
41 
42 // Wrapper function that creates a nicer error message (than a bare
43 // ValueOrDie()) if the platform we intend to test is not available.
GetOrCreateLocalClientOrDie(const LocalClientOptions & client_options)44 LocalClient* GetOrCreateLocalClientOrDie(
45     const LocalClientOptions& client_options) {
46   StatusOr<LocalClient*> result =
47       ClientLibrary::GetOrCreateLocalClient(client_options);
48   TF_CHECK_OK(result.status()) << " could not create local client for testing";
49   return result.ValueOrDie();
50 }
51 
52 // Helper functions to get the reference platform.
GetReferencePlatform()53 se::Platform* GetReferencePlatform() {
54   auto result = PlatformUtil::GetPlatform(kInterpreter);
55   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
56   return result.ValueOrDie();
57 }
58 
59 }  // namespace
60 
ClientLibraryTestBase(se::Platform * platform,const LocalClientOptions & client_options)61 ClientLibraryTestBase::ClientLibraryTestBase(
62     se::Platform* platform, const LocalClientOptions& client_options)
63     : client_(GetOrCreateLocalClientOrDie(client_options)),
64       execution_options_(CreateDefaultExecutionOptions()) {
65   CHECK_EQ(platform, client_options.platform());
66 
67   LocalClientOptions ref_options;
68   ref_options.set_platform(GetReferencePlatform());
69   ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
70 
71   // Disabling constant_folding so that tests (usually written using Constants)
72   // will exercise the intended code paths, instead of being constant folded.
73   //
74   // TODO(b/38354253): Constant folding is currently disabled. Change tests to
75   // use Parameters instead of Constants, and re-enable constant folding by
76   // default.
77   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
78       "constant_folding");
79 
80   execution_options_.mutable_debug_options()
81       ->set_xla_hlo_evaluator_use_fast_path(true);
82 }
83 
ClientLibraryTestBase(se::Platform * platform)84 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
85     : execution_options_(CreateDefaultExecutionOptions()) {
86   LocalClientOptions default_options;
87   default_options.set_platform(platform);
88   client_ = GetOrCreateLocalClientOrDie(default_options);
89 
90   LocalClientOptions ref_options;
91   ref_options.set_platform(GetReferencePlatform());
92   ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
93 
94   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
95       "constant_folding");
96 
97   execution_options_.mutable_debug_options()
98       ->set_xla_hlo_evaluator_use_fast_path(true);
99 }
100 
TestName() const101 string ClientLibraryTestBase::TestName() const {
102   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
103 }
104 
Execute(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)105 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
106     XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
107   // Build the computation, as a convenience.
108   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
109   return client_->Execute(computation, arguments, &execution_options_);
110 }
111 
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)112 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
113     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
114     const Shape* shape_with_output_layout) {
115   ExecutionOptions execution_options = execution_options_;
116   if (shape_with_output_layout != nullptr) {
117     *execution_options.mutable_shape_with_output_layout() =
118         shape_with_output_layout->ToProto();
119   }
120   return client_->ExecuteAndTransfer(computation, arguments,
121                                      &execution_options);
122 }
123 
ExecuteAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)124 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
125     XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
126     const Shape* shape_with_output_layout) {
127   // Build the computation, as a convenience.
128   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
129   return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
130 }
131 
ExecuteAndTransferReference(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)132 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
133     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
134     const Shape* shape_with_output_layout) {
135   ExecutionOptions execution_options = execution_options_;
136   if (shape_with_output_layout != nullptr) {
137     *execution_options.mutable_shape_with_output_layout() =
138         shape_with_output_layout->ToProto();
139   }
140   execution_options.clear_device_handles();
141   return ref_client_->ExecuteAndTransfer(computation, arguments,
142                                          &execution_options);
143 }
144 
ExecuteToString(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)145 string ClientLibraryTestBase::ExecuteToString(
146     XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
147   auto computation_status = builder->Build();
148   if (!computation_status.ok()) {
149     return computation_status.status().ToString();
150   }
151   auto computation = computation_status.ConsumeValueOrDie();
152 
153   auto result =
154       client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
155   if (!result.ok()) {
156     return result.status().ToString();
157   } else {
158     return result.ValueOrDie().ToString();
159   }
160 }
161 
ComputeAndCompareR1(XlaBuilder * builder,const tensorflow::core::Bitmap & expected,absl::Span<GlobalData * const> arguments)162 void ClientLibraryTestBase::ComputeAndCompareR1(
163     XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
164     absl::Span<GlobalData* const> arguments) {
165   Literal expected_literal = LiteralUtil::CreateR1(expected);
166   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
167                                                   arguments);
168 }
169 
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,const Shape * shape_with_layout)170 void ClientLibraryTestBase::ComputeAndCompareLiteral(
171     XlaBuilder* builder, const Literal& expected,
172     absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
173   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
174                                                   shape_with_layout));
175 }
176 
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error,const Shape * shape_with_layout)177 void ClientLibraryTestBase::ComputeAndCompareLiteral(
178     XlaBuilder* builder, const Literal& expected,
179     absl::Span<GlobalData* const> arguments, ErrorSpec error,
180     const Shape* shape_with_layout) {
181   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
182                                                   error, shape_with_layout));
183 }
184 
ComputeAndCompareLiteralWithAllOutputLayouts(const xla::XlaComputation & computation,const Literal & expected,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const string & error_message)> & verify_output)185 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
186     const xla::XlaComputation& computation, const Literal& expected,
187     absl::Span<GlobalData* const> arguments,
188     const std::function<void(const Literal& actual,
189                              const string& error_message)>& verify_output) {
190   // Try with no layout requirement.
191   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
192   verify_output(actual, "");
193 
194   // Try with all output layouts.
195   std::vector<int64> minor_to_major(expected.shape().rank());
196   std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
197   do {
198     auto layout = ShapeUtil::MakeShapeWithLayout(
199         expected.shape().element_type(),
200         AsInt64Slice(expected.shape().dimensions()), minor_to_major);
201     TF_ASSIGN_OR_RETURN(auto actual,
202                         ExecuteAndTransfer(computation, arguments, &layout));
203     verify_output(actual,
204                   absl::StrCat("Test with output layout: ",
205                                ShapeUtil::HumanStringWithLayout(layout)));
206   } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
207   return Status::OK();
208 }
209 
ComputeAndCompareLiteralWithAllInputLayouts(const xla::XlaComputation & computation,const Literal &,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const string & error_message)> & verify_output,const Shape * output_with_layout)210 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
211     const xla::XlaComputation& computation, const Literal& /*expected*/,
212     absl::Span<GlobalData* const> arguments,
213     const std::function<void(const Literal& actual,
214                              const string& error_message)>& verify_output,
215     const Shape* output_with_layout) {
216   std::vector<GlobalData*> arguments_with_layout;
217   std::vector<string> layout_strings;
218   // This is a recursive function. It's an std::function instead of a lambda
219   // because it needs to capture itself. The index is the index of the argument
220   // to try all layouts for.
221   std::function<Status(int64)> choose;
222   choose = [&, this](int64 index) -> Status {
223     if (index < arguments.size()) {
224       // Try out all layouts for the operand.
225       TF_ASSIGN_OR_RETURN(auto literal,
226                           client_->Transfer(*arguments[index], nullptr));
227       // Skip tuples because they don't have a rank.
228       if (literal.shape().IsTuple()) {
229         layout_strings.push_back(
230             ShapeUtil::HumanStringWithLayout(literal.shape()));
231         arguments_with_layout.push_back(arguments[index]);
232         TF_RETURN_IF_ERROR(choose(index + 1));
233         arguments_with_layout.pop_back();
234         layout_strings.pop_back();
235         return Status::OK();
236       }
237 
238       std::vector<int64> minor_to_major(literal.shape().rank());
239       std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
240       do {
241         auto literal_relayout =
242             literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
243         layout_strings.push_back(
244             ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
245         TF_ASSIGN_OR_RETURN(auto data,
246                             client_->TransferToServer(literal_relayout));
247         arguments_with_layout.push_back(data.get());
248         TF_RETURN_IF_ERROR(choose(index + 1));
249         arguments_with_layout.pop_back();
250         layout_strings.pop_back();
251       } while (
252           std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
253       return Status::OK();
254     }
255 
256     // Every argument has an assigned layout.
257     TF_ASSIGN_OR_RETURN(
258         auto actual,
259         ExecuteAndTransfer(computation,
260                            absl::Span<GlobalData* const>(arguments_with_layout),
261                            output_with_layout));
262     string error_message = "Test with input layouts: ";
263     for (const auto& str : layout_strings) {
264       absl::StrAppend(&error_message, str, " ");
265     }
266     verify_output(actual, error_message);
267     return Status::OK();
268   };
269 
270   return choose(0);
271 }
272 
ComputeAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)273 StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
274     XlaBuilder* builder, absl::Span<GlobalData* const> arguments_passed_in,
275     const Shape* shape_with_layout) {
276   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
277                                      arguments_passed_in.end());
278 
279   // Transfer and use elements of arguments_, if the AddParam() API was used.
280   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
281   if (!arguments_.empty()) {
282     CHECK(arguments.empty());
283     for (const auto& argument : arguments_) {
284       TF_ASSIGN_OR_RETURN(
285           std::unique_ptr<GlobalData> owned_argument,
286           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
287       owning_arguments.push_back(std::move(owned_argument));
288       arguments.push_back(owning_arguments.back().get());
289     }
290   }
291 
292   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
293   return ExecuteAndTransfer(computation, arguments, shape_with_layout);
294 }
295 
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)296 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
297     XlaBuilder* builder, const Literal& expected,
298     absl::Span<GlobalData* const> arguments_passed_in,
299     const Shape* shape_with_layout) {
300   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
301                                      arguments_passed_in.end());
302 
303   // Transfer and use elements of arguments_, if the AddParam() API was used.
304   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
305   if (!arguments_.empty()) {
306     CHECK(arguments.empty());
307     for (const auto& argument : arguments_) {
308       TF_ASSIGN_OR_RETURN(
309           std::unique_ptr<GlobalData> owned_argument,
310           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
311       owning_arguments.push_back(std::move(owned_argument));
312       arguments.push_back(owning_arguments.back().get());
313     }
314   }
315 
316   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
317   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
318       ShapeUtil::ElementIsComplex(expected.shape())) {
319     LOG(WARNING) << "performing exact comparison of floating point numbers";
320   }
321   // We allow using a float expected literal for a bfloat16 output. In this
322   // case, we need to convert the expected literal to bfloat16.
323   const Literal* expected_ptr = &expected;
324   Literal converted_expected;
325   Shape layout_shape;
326   if (use_bfloat16_) {
327     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
328     expected_ptr = &converted_expected;
329     if (shape_with_layout != nullptr) {
330       layout_shape = *shape_with_layout;
331       ShapeUtil::ForEachMutableSubshape(
332           &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
333             if (subshape->element_type() == F32) {
334               subshape->set_element_type(BF16);
335             }
336           });
337       shape_with_layout = &layout_shape;
338     }
339   }
340   auto expect_equal = [&](const Literal& actual, const string& error_message) {
341     EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
342   };
343   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
344     return ComputeAndCompareLiteralWithAllOutputLayouts(
345         computation, *expected_ptr, arguments, expect_equal);
346   }
347   if (execution_options_.debug_options().xla_test_all_input_layouts()) {
348     return ComputeAndCompareLiteralWithAllInputLayouts(
349         computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
350   }
351   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
352                                                       shape_with_layout));
353   EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
354   return Status::OK();
355 }
356 
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,ErrorSpec error,const Shape * shape_with_layout)357 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
358     XlaBuilder* builder, const Literal& expected,
359     absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
360     const Shape* shape_with_layout) {
361   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
362                                      arguments_passed_in.end());
363 
364   // Transfer and use elements of arguments_, if the AddParam() API was used.
365   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
366   if (!arguments_.empty()) {
367     CHECK(arguments.empty());
368     for (const auto& argument : arguments_) {
369       TF_ASSIGN_OR_RETURN(
370           std::unique_ptr<GlobalData> owned_argument,
371           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
372       owning_arguments.push_back(std::move(owned_argument));
373       arguments.push_back(owning_arguments.back().get());
374     }
375   }
376 
377   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
378   // We allow using a float expected literal for a bfloat16 output. In this
379   // case, we need to convert the expected literal to bfloat16.
380   const Literal* expected_ptr = &expected;
381   Literal converted_expected;
382   Shape layout_shape;
383   if (use_bfloat16_) {
384     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
385     expected_ptr = &converted_expected;
386     if (shape_with_layout != nullptr) {
387       layout_shape = *shape_with_layout;
388       ShapeUtil::ForEachMutableSubshape(
389           &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
390             if (subshape->element_type() == F32) {
391               subshape->set_element_type(BF16);
392             }
393           });
394       shape_with_layout = &layout_shape;
395     }
396   }
397   auto expect_near = [&](const Literal& actual, const string& error_message) {
398     EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
399         << error_message;
400   };
401   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
402     return ComputeAndCompareLiteralWithAllOutputLayouts(
403         computation, *expected_ptr, arguments, expect_near);
404   }
405   if (execution_options_.debug_options().xla_test_all_input_layouts()) {
406     return ComputeAndCompareLiteralWithAllInputLayouts(
407         computation, *expected_ptr, arguments, expect_near, shape_with_layout);
408   }
409   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
410                                                       shape_with_layout));
411   EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
412   return Status::OK();
413 }
414 
ComputeAndCompareR1U8(XlaBuilder * builder,absl::string_view expected,absl::Span<GlobalData * const> arguments)415 void ClientLibraryTestBase::ComputeAndCompareR1U8(
416     XlaBuilder* builder, absl::string_view expected,
417     absl::Span<GlobalData* const> arguments) {
418   auto actual_status = ExecuteAndTransfer(builder, arguments);
419   EXPECT_IS_OK(actual_status.status());
420   if (!actual_status.ok()) {
421     return;
422   }
423   auto actual = actual_status.ConsumeValueOrDie();
424 
425   // Turn the expected value into a literal.
426   Literal expected_literal = LiteralUtil::CreateR1U8(expected);
427 
428   VLOG(1) << "expected: " << expected_literal.ToString();
429   VLOG(1) << "actual:   " << actual.ToString();
430 
431   EXPECT_EQ(expected, actual.GetR1U8AsString());
432 }
433 
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments)434 void ClientLibraryTestBase::ComputeAndCompareTuple(
435     XlaBuilder* builder, const Literal& expected,
436     absl::Span<GlobalData* const> arguments) {
437   auto actual_status = ExecuteAndTransfer(builder, arguments);
438   EXPECT_IS_OK(actual_status.status());
439   if (!actual_status.ok()) {
440     return;
441   }
442   auto actual = actual_status.ConsumeValueOrDie();
443   EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
444 }
445 
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)446 void ClientLibraryTestBase::ComputeAndCompareTuple(
447     XlaBuilder* builder, const Literal& expected,
448     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
449   auto actual_status = ExecuteAndTransfer(builder, arguments);
450   EXPECT_IS_OK(actual_status.status());
451   if (!actual_status.ok()) {
452     return;
453   }
454   auto actual = actual_status.ConsumeValueOrDie();
455   EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
456 }
457 
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments)458 void ClientLibraryTestBase::ComputeAndCompare(
459     XlaBuilder* builder, absl::Span<const Literal> arguments) {
460   auto status_or_data = ComputeValueAndReference(builder, arguments);
461   EXPECT_IS_OK(status_or_data);
462   if (!status_or_data.ok()) {
463     return;
464   }
465   Literal reference, result;
466   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
467   EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
468 }
469 
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments,ErrorSpec error)470 void ClientLibraryTestBase::ComputeAndCompare(
471     XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
472   auto status_or_data = ComputeValueAndReference(builder, arguments);
473   EXPECT_IS_OK(status_or_data);
474   if (!status_or_data.ok()) {
475     return;
476   }
477   Literal reference, result;
478   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
479   EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
480 }
481 
482 StatusOr<std::pair<Literal, Literal>>
ComputeValueAndReference(XlaBuilder * builder,absl::Span<const Literal> arguments)483 ClientLibraryTestBase::ComputeValueAndReference(
484     XlaBuilder* builder, absl::Span<const Literal> arguments) {
485   // Transfer the arguments to the executor service. We put the unique_ptr's
486   // into a vector to keep the data alive on the service until the end of this
487   // function.
488   std::vector<std::unique_ptr<GlobalData>> argument_data;
489   std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
490 
491   // Use `arguments_` if the AddParam() API was used.  Otherwise, use
492   // plain `arguments`.
493   if (!arguments_.empty()) {
494     CHECK_EQ(arguments.size(), 0);
495     arguments = arguments_;
496   }
497 
498   for (const auto& arg : arguments) {
499     TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
500     TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
501     argument_data.push_back(std::move(data));
502     ref_argument_data.push_back(std::move(ref_data));
503   }
504 
505   // Create raw pointers to the GlobalData for the rest of the call stack.
506   std::vector<GlobalData*> argument_data_ptr;
507   std::transform(
508       argument_data.begin(), argument_data.end(),
509       std::back_inserter(argument_data_ptr),
510       [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
511   std::vector<GlobalData*> ref_argument_data_ptr;
512   std::transform(
513       ref_argument_data.begin(), ref_argument_data.end(),
514       std::back_inserter(ref_argument_data_ptr),
515       [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
516 
517   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
518 
519   TF_ASSIGN_OR_RETURN(auto result,
520                       ExecuteAndTransfer(computation, argument_data_ptr));
521 
522   TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference(
523                                           computation, ref_argument_data_ptr));
524 
525   return std::make_pair(std::move(reference), std::move(result));
526 }
527 
CreateScalarRelu()528 XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
529   XlaBuilder builder("relu");
530   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
531   auto z_value = Parameter(&builder, 0, shape, "z_value");
532   auto zero = use_bfloat16_
533                   ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
534                   : ConstantR0<float>(&builder, 0.0f);
535   Max(z_value, zero);
536   auto computation_status = builder.Build();
537   TF_CHECK_OK(computation_status.status());
538   return computation_status.ConsumeValueOrDie();
539 }
540 
CreateScalarMax()541 XlaComputation ClientLibraryTestBase::CreateScalarMax() {
542   XlaBuilder builder("max");
543   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
544   auto x = Parameter(&builder, 0, shape, "x");
545   auto y = Parameter(&builder, 1, shape, "y");
546   Max(x, y);
547   auto computation_status = builder.Build();
548   TF_CHECK_OK(computation_status.status());
549   return computation_status.ConsumeValueOrDie();
550 }
551 
CreateScalarReluSensitivity()552 XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
553   XlaBuilder builder("relu_sensitivity");
554   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
555   auto activation = Parameter(&builder, 0, shape, "activation");
556   auto backprop = Parameter(&builder, 1, shape, "backprop");
557   auto zero = use_bfloat16_
558                   ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
559                   : ConstantR0<float>(&builder, 0.0f);
560   auto activation_gtz = Gt(activation, zero);
561   Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
562 
563   auto computation_status = builder.Build();
564   TF_CHECK_OK(computation_status.status());
565   return computation_status.ConsumeValueOrDie();
566 }
567 
CreatePatternedMatrix(int rows,int cols,float offset)568 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
569     int rows, int cols, float offset) {
570   auto array = absl::make_unique<Array2D<float>>(rows, cols);
571   for (int64 row = 0; row < rows; ++row) {
572     for (int64 col = 0; col < cols; ++col) {
573       (*array)(row, col) = col + (row * 1000.0f) + offset;
574     }
575   }
576   return array;
577 }
578 
579 std::unique_ptr<Array2D<float>>
CreatePatternedMatrixWithZeroPadding(int rows,int cols,int rows_padded,int cols_padded)580 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
581                                                             int rows_padded,
582                                                             int cols_padded) {
583   CHECK_GE(rows_padded, rows);
584   CHECK_GE(cols_padded, cols);
585   auto array = absl::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
586   for (int64 row = 0; row < rows; ++row) {
587     for (int64 col = 0; col < cols; ++col) {
588       (*array)(row, col) = col + (row * 1000.0f);
589     }
590   }
591   return array;
592 }
593 
AddParam(const Literal & argument,XlaBuilder * builder)594 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
595                                       XlaBuilder* builder) {
596   arguments_.push_back(argument.Clone());
597   return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
598                    MaybeConvertShapeToBfloat16(argument.shape()), "");
599 }
600 
CreateConstantFromLiteral(const Literal & literal,XlaBuilder * builder)601 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
602                                                        XlaBuilder* builder) {
603   return ConstantLiteral(builder, use_bfloat16_
604                                       ? LiteralUtil::ConvertF32ToBF16(literal)
605                                       : LiteralSlice(literal));
606 }
607 
608 std::unique_ptr<GlobalData>
CreateParameterAndTransferLiteral(int64 parameter_number,const Literal & literal,const string & name,XlaBuilder * builder,XlaOp * data_handle)609 ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
610                                                          const Literal& literal,
611                                                          const string& name,
612                                                          XlaBuilder* builder,
613                                                          XlaOp* data_handle) {
614   return CreateParameterAndTransferLiteral(parameter_number, literal, name,
615                                            nullptr, builder, data_handle);
616 }
617 
// Returns a copy of `shape`. When use_bfloat16_ is set, every F32 subshape in
// the copy is retyped as BF16.
Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
  Shape converted = shape;
  if (use_bfloat16_) {
    ShapeUtil::ForEachMutableSubshape(
        &converted, [](Shape* subshape, const ShapeIndex& /*index*/) {
          if (subshape->element_type() == F32) {
            subshape->set_element_type(BF16);
          }
        });
  }
  return converted;
}
631 
// Returns `literal` converted from F32 to BF16 when use_bfloat16_ is set, and
// an unmodified deep copy otherwise.
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
    const Literal& literal) {
  return use_bfloat16_ ? LiteralUtil::ConvertF32ToBF16(literal)
                       : literal.Clone();
}
639 
640 std::unique_ptr<GlobalData>
CreateParameterAndTransferLiteral(int64 parameter_number,const Literal & literal,const string & name,const DeviceHandle * device_handle,XlaBuilder * builder,XlaOp * data_handle)641 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
642     int64 parameter_number, const Literal& literal, const string& name,
643     const DeviceHandle* device_handle, XlaBuilder* builder,
644     XlaOp* data_handle) {
645   Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
646   std::unique_ptr<GlobalData> data =
647       client_->TransferToServer(param_literal, device_handle)
648           .ConsumeValueOrDie();
649   *data_handle =
650       Parameter(builder, parameter_number, param_literal.shape(), name);
651   return data;
652 }
653 
654 }  // namespace xla
655