• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
45 
46 namespace xla {
47 
48 // Sets the use_bfloat16 on a container of test cases according to the values in
49 // use_bfloat16_params. Generates one set of test cases for each values in
50 // use_bfloat16_params with that value. Returns the result.
51 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)52 std::vector<TestCase> ExpandUseBfloat16(
53     absl::Span<const bool> use_bfloat16_params,
54     absl::Span<const TestCase> specs) {
55   std::vector<TestCase> expanded;
56   for (bool use_bfloat16 : use_bfloat16_params) {
57     for (const auto& spec : specs) {
58       expanded.push_back(spec);
59       expanded.back().use_bfloat16 = use_bfloat16;
60     }
61   }
62   return expanded;
63 }
64 
65 // A client library test establishes an in-process XLA client connection.
66 class ClientLibraryTestBase : public ManifestCheckingTest {
67  protected:
68   explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
69 
70   // Creates a new ClientLibraryTestBase with custom client options.
71   ClientLibraryTestBase(se::Platform* platform,
72                         const LocalClientOptions& client_options);
73 
74   // Returns the name of the test currently being run.
75   string TestName() const;
76 
SetFastMathDisabled(bool disabled)77   void SetFastMathDisabled(bool disabled) {
78     auto* opts = execution_options_.mutable_debug_options();
79     opts->set_xla_cpu_enable_fast_math(!disabled);
80     opts->set_xla_cpu_enable_fast_min_max(!disabled);
81     opts->set_xla_gpu_enable_fast_min_max(!disabled);
82   }
83 
SetSeed(uint64 seed)84   void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
85 
86   // Provides mutable access to the execution DebugOptions field; this lets
87   // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()88   DebugOptions* mutable_debug_options() {
89     return execution_options_.mutable_debug_options();
90   }
91 
92   // TODO(b/25566808): Add helper that populates a literal from a testdata file.
93 
94   // Convenience methods for building and running a computation with the member
95   // execution options. Modify execution_options_ in your test if you want to
96   // customize the options.
97   StatusOr<std::unique_ptr<GlobalData>> Execute(
98       XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
99 
100   StatusOr<Literal> ExecuteAndTransfer(
101       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
102       const Shape* shape_with_output_layout = nullptr);
103 
104   StatusOr<Literal> ExecuteAndTransfer(
105       const XlaComputation& computation,
106       absl::Span<GlobalData* const> arguments,
107       const Shape* shape_with_output_layout = nullptr);
108 
109   // This executes the computation via the reference client (which connects a
110   // interpreter backend). The result is used as the expected value of the
111   // computation.
112   StatusOr<Literal> ExecuteAndTransferReference(
113       const XlaComputation& computation,
114       absl::Span<GlobalData* const> arguments,
115       const Shape* shape_with_output_layout = nullptr);
116 
117   // Run a computation and return its value as a string. If an error
118   // occurs, then instead return the error as a string.
119   string ExecuteToString(XlaBuilder* builder,
120                          absl::Span<GlobalData* const> arguments);
121 
122   // Convenience methods for building and running a computation, transferring
123   // the result, and comparing it to the expected value(s). Methods are
124   // templated on the native host type which maps to specific XLA types (See
125   // XlaBuilder for details). For each rank, two forms are
126   // provided: one for floating point types with an ErrorSpec parameter, and one
127   // for integral types without the ErrorSpec parameter.
128   template <typename NativeT>
129   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
130                            absl::Span<GlobalData* const> arguments);
131   template <typename NativeT>
132   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
133                            absl::Span<GlobalData* const> arguments,
134                            ErrorSpec error);
135 
136   template <typename NativeT>
137   void ComputeAndCompareR1(XlaBuilder* builder,
138                            absl::Span<const NativeT> expected,
139                            absl::Span<GlobalData* const> arguments);
140   template <typename NativeT>
141   void ComputeAndCompareR1(XlaBuilder* builder,
142                            absl::Span<const NativeT> expected,
143                            absl::Span<GlobalData* const> arguments,
144                            ErrorSpec error);
145 
146   // As above, but uses a bitmap to hold the predicate vector to avoid
147   // deficiencies of vector<bool>.
148   void ComputeAndCompareR1(XlaBuilder* builder,
149                            const tensorflow::core::Bitmap& expected,
150                            absl::Span<GlobalData* const> arguments);
151 
152   template <typename NativeT>
153   void ComputeAndCompareR2(XlaBuilder* builder,
154                            const Array2D<NativeT>& expected,
155                            absl::Span<GlobalData* const> arguments);
156   template <typename NativeT>
157   void ComputeAndCompareR2(XlaBuilder* builder,
158                            const Array2D<NativeT>& expected,
159                            absl::Span<GlobalData* const> arguments,
160                            ErrorSpec error);
161 
162   template <typename NativeT>
163   void ComputeAndCompareR3(XlaBuilder* builder,
164                            const Array3D<NativeT>& expected,
165                            absl::Span<GlobalData* const> arguments);
166   template <typename NativeT>
167   void ComputeAndCompareR3(XlaBuilder* builder,
168                            const Array3D<NativeT>& expected,
169                            absl::Span<GlobalData* const> arguments,
170                            ErrorSpec error);
171 
172   template <typename NativeT>
173   void ComputeAndCompareR4(XlaBuilder* builder,
174                            const Array4D<NativeT>& expected,
175                            absl::Span<GlobalData* const> arguments);
176   template <typename NativeT>
177   void ComputeAndCompareR4(XlaBuilder* builder,
178                            const Array4D<NativeT>& expected,
179                            absl::Span<GlobalData* const> arguments,
180                            ErrorSpec error);
181 
182   // Build and run the computation and compare the result with the given
183   // literal. shape_with_layout indicates the result layout to request when
184   // calling Execute.
185   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
186                                 absl::Span<GlobalData* const> arguments,
187                                 const Shape* shape_with_layout = nullptr);
188   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
189                                 absl::Span<GlobalData* const> arguments,
190                                 ErrorSpec error,
191                                 const Shape* shape_with_layout = nullptr);
192 
193   // Build and run the computation and return the result as a literal.
194   // shape_with_layout indicates the result layout to request when calling
195   // Execute.
196   StatusOr<Literal> ComputeAndTransfer(
197       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
198       const Shape* shape_with_layout = nullptr);
199 
200   // ComputeAndCompare variant which returns an error status.
201   Status ComputeAndCompareLiteralWithStatus(
202       XlaBuilder* builder, const Literal& expected,
203       absl::Span<GlobalData* const> arguments,
204       const Shape* shape_with_layout = nullptr);
205   Status ComputeAndCompareLiteralWithStatus(
206       XlaBuilder* builder, const Literal& expected,
207       absl::Span<GlobalData* const> arguments, ErrorSpec error,
208       const Shape* shape_with_layout = nullptr);
209 
210   // Compare the result of the computation to a strings. In XLA strings are
211   // represented using rank-1 U8 shapes.
212   void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
213                              absl::Span<GlobalData* const> arguments);
214 
215   // Convenience method for running a built computation, transferring the
216   // result, and comparing it to the expected tuple literal.
217   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218                               absl::Span<GlobalData* const> arguments);
219   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
220                               absl::Span<GlobalData* const> arguments,
221                               ErrorSpec error);
222 
223   // Convenience method for running a built computation and comparing the result
224   // with the reference result.
225   void ComputeAndCompare(XlaBuilder* builder,
226                          absl::Span<const Literal> arguments);
227   void ComputeAndCompare(XlaBuilder* builder,
228                          absl::Span<const Literal> arguments, ErrorSpec error);
229   template <typename NativeT>
230   void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
231                          absl::Span<GlobalData* const> arguments);
232   template <typename NativeT>
233   void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
234                          absl::Span<GlobalData* const> arguments,
235                          ErrorSpec error);
236   // Create scalar operations for use in reductions.
237   XlaComputation CreateScalarRelu();
238   XlaComputation CreateScalarMax();
239   XlaComputation CreateScalarReluSensitivity();
240 
241   // Special case convenience functions for creating filled arrays.
242 
243   // Creates an array of pseudorandom values lying between the given minimum and
244   // maximum values.
245   template <typename NativeT>
246   std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
247                                             NativeT max_value, uint32 seed);
248   template <typename NativeT>
249   std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
250                                                          const int cols,
251                                                          NativeT min_value,
252                                                          NativeT max_value,
253                                                          uint32 seed);
254 
255   // Creates a (rows x cols) array filled in the following form:
256   //
257   //  [      0              1 ...                   cols-1]
258   //  [  1,000          1,001 ...          1000.0 + cols-1]
259   //  [    ...            ... ...                      ...]
260   //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
261   //
262   // If provided, offset is added uniformly to every element (e.g. an offset of
263   // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
264   std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
265                                                         const int cols,
266                                                         float offset = 0.0);
267 
268   // Creates a (rows x cols) array as above, padded out to
269   // (rows_padded x cols_padded) with zeroes.  Requires rows_padded >= rows
270   // and cols_padded > cols.
271   std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
272       const int rows, const int cols, const int rows_padded,
273       const int cols_padded);
274 
275   // Creates a parameter instruction, transfers the literal for the parameter to
276   // server, then stores into "data_handle" the global handle for that
277   // parameter. When the use_bfloat16 flag is set but the literal has F32
278   // elements, the literal will be converted to BF16 before being transferred.
279   StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
280       int64 parameter_number, const Literal& literal, const string& name,
281       XlaBuilder* builder, XlaOp* data_handle);
282 
283   // As above, but the caller can specify the device that the literal is
284   // transferred to. If device_handle is nullptr, the literal will be
285   // transferred to the default device.
286   StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
287       int64 parameter_number, const Literal& literal, const string& name,
288       const DeviceHandle* device_handle, XlaBuilder* builder,
289       XlaOp* data_handle);
290 
291   // Creates a parameter instruction and sets the value that will be passed to
292   // the computation as specified. This function must be used for all parameters
293   // or none and no parameters must be passed when invoking the computation if
294   // using this mechanism. If using this mechanism, then each parameter must be
295   // set exactly once. The first added parameter gets index 0, then 1 and so on.
296   XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
297 
298   template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)299   XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
300     return AddParam(LiteralUtil::CreateFromArray(argument), builder);
301   }
302 
303   // Creates a constant instruction with the given literal. When the
304   // use_bfloat16 flag is set but the literal has F32 elements, the elements
305   // will be converted to BF16s.
306   XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
307 
308   // Creates a constant instruction with the given array. When the use_bfloat16
309   // flag is set but the array has float elements, the elements will be
310   // converted to bfloat16s.
311 
312   template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)313   XlaOp CreateConstantFromArray(const Array<NativeT>& array,
314                                 XlaBuilder* builder) {
315     return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
316                                      builder);
317   }
318 
319   // Same as CreateConstantFromArray, but for scalars.
320   template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)321   XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
322     return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
323                                      builder);
324   }
325 
326   // Creates a parameter instruction that wraps a given value and then stores
327   // into "data_handle" the global handle for that parameter.
328   //
329   // "parameter_number" is the parameter number.
330   // "name" is the name of the parameter instruction.
331   //
332   // When the use_bfloat16 flag is set but NativeT is float, the data will be
333   // converted to bfloat16.
334   template <typename NativeT>
335   std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
336                                                 int64 parameter_number,
337                                                 const string& name,
338                                                 XlaBuilder* builder,
339                                                 XlaOp* data_handle);
340 
341   // Creates a parameter instruction that wraps the given values and then stores
342   // into "data_handle" the global handle for that parameter.
343   //
344   // "parameter_number" is the parameter number.
345   // "name" is the name of the parameter instruction.
346   //
347   // When the use_bfloat16 flag is set but NativeT is float, the data will be
348   // converted to bfloat16.
349   template <typename NativeT>
350   std::unique_ptr<GlobalData> CreateR1Parameter(
351       absl::Span<const NativeT> values, int64 parameter_number,
352       const string& name, XlaBuilder* builder, XlaOp* data_handle);
353 
354   // Creates a parameter instruction that wraps the given constant array
355   // "array_2d" and then stores it to the global handle for that parameter
356   // "data_handle".
357   //
358   // "parameter_number" is the parameter number.
359   // "name" is the name of the parameter instruction.
360   //
361   // When the use_bfloat16 flag is set but NativeT is float, the data will be
362   // converted to bfloat16.
363   template <typename NativeT>
364   std::unique_ptr<GlobalData> CreateR2Parameter(
365       const Array2D<NativeT>& array_2d, int64 parameter_number,
366       const string& name, XlaBuilder* builder, XlaOp* data_handle);
367 
368   // Creates a parameter instruction that wraps the given constant array
369   // "array_3d" and then stores it to the global handle for that parameter
370   // "data_handle".
371   //
372   // "parameter_number" is the parameter number.
373   // "name" is the name of the parameter instruction.
374   //
375   // When the use_bfloat16 flag is set but NativeT is float, the data will be
376   // converted to bfloat16.
377   template <typename NativeT>
378   std::unique_ptr<GlobalData> CreateR3Parameter(
379       const Array3D<NativeT>& array_3d, int64 parameter_number,
380       const string& name, XlaBuilder* builder, XlaOp* data_handle);
381 
382   // Creates a parameter instruction that wraps the given constant array
383   // "array_4d" and then stores it to the global handle for that parameter
384   // "data_handle".
385   //
386   // "parameter_number" is the parameter number.
387   // "name" is the name of the parameter instruction.
388   //
389   // When the use_bfloat16 flag is set but NativeT is float, the data will be
390   // converted to bfloat16.
391   template <typename NativeT>
392   std::unique_ptr<GlobalData> CreateR4Parameter(
393       const Array4D<NativeT>& array_4d, int64 parameter_number,
394       const string& name, XlaBuilder* builder, XlaOp* data_handle);
395 
396   template <typename NativeT>
397   std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
398                                               int64 parameter_number,
399                                               const string& name,
400                                               XlaBuilder* builder,
401                                               XlaOp* data_handle);
402 
403   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
404   // tests with all float-type input/output converted to bfloat16.
use_bfloat16()405   bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)406   void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
407 
408   // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()409   PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
410 
411   // Executes the computation and calculates the expected reference value using
412   // the reference client. Returns two literals in the order of (expected,
413   // actual).
414   StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
415       XlaBuilder* builder, absl::Span<const Literal> arguments);
416 
417   // Converts an f32 literal to bf16 if use_bfloat16_ is true.
418   Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
419 
420   LocalClient* client_;
421   LocalClient* ref_client_;  // To compute reference result.
422   ExecutionOptions execution_options_;
423 
424  private:
425   Status ComputeAndCompareLiteralWithAllOutputLayouts(
426       const xla::XlaComputation& computation, const Literal& expected,
427       absl::Span<GlobalData* const> arguments,
428       const std::function<void(const Literal& actual,
429                                const string& error_message)>& verify_output);
430   Status ComputeAndCompareLiteralWithAllInputLayouts(
431       const xla::XlaComputation& computation, const Literal& expected,
432       absl::Span<GlobalData* const> arguments,
433       const std::function<void(const Literal& actual,
434                                const string& error_message)>& verify_output,
435       const Shape* output_with_layout = nullptr);
436 
437   // Converts an f32 shape to bf16 if use_bfloat16_ is true.
438   Shape MaybeConvertShapeToBfloat16(const Shape& shape);
439 
440   // Whether to run tests with all float-type input/output converted to
441   // bfloat16.
442   bool use_bfloat16_ = false;
443 
444   // Arguments to be passed to the computation when it runs.
445   std::vector<Literal> arguments_;
446 };
447 
448 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)449 void ClientLibraryTestBase::ComputeAndCompareR0(
450     XlaBuilder* builder, NativeT expected,
451     absl::Span<GlobalData* const> arguments) {
452   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
453   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
454                                                   arguments);
455 }
456 
457 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)458 void ClientLibraryTestBase::ComputeAndCompareR0(
459     XlaBuilder* builder, NativeT expected,
460     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
461   static_assert(std::is_same<NativeT, float>::value ||
462                     std::is_same<NativeT, double>::value ||
463                     std::is_same<NativeT, bfloat16>::value ||
464                     std::is_same<NativeT, half>::value ||
465                     std::is_same<NativeT, complex64>::value ||
466                     std::is_same<NativeT, complex128>::value,
467                 "Float or complex type required when specifying an ErrorSpec");
468   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
469   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
470                                                   arguments, error);
471 }
472 
473 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)474 void ClientLibraryTestBase::ComputeAndCompareR1(
475     XlaBuilder* builder, absl::Span<const NativeT> expected,
476     absl::Span<GlobalData* const> arguments) {
477   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
478   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
479                                                   arguments);
480 }
481 
482 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)483 void ClientLibraryTestBase::ComputeAndCompareR1(
484     XlaBuilder* builder, absl::Span<const NativeT> expected,
485     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
486   static_assert(std::is_same<NativeT, float>::value ||
487                     std::is_same<NativeT, double>::value ||
488                     std::is_same<NativeT, bfloat16>::value ||
489                     std::is_same<NativeT, half>::value ||
490                     std::is_same<NativeT, complex64>::value ||
491                     std::is_same<NativeT, complex128>::value,
492                 "Float or complex type required when specifying an ErrorSpec");
493   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
494   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
495                                                   arguments, error);
496 }
497 
498 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)499 void ClientLibraryTestBase::ComputeAndCompareR2(
500     XlaBuilder* builder, const Array2D<NativeT>& expected,
501     absl::Span<GlobalData* const> arguments) {
502   Literal expected_literal =
503       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
504   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
505                                                   arguments);
506 }
507 
508 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)509 void ClientLibraryTestBase::ComputeAndCompareR2(
510     XlaBuilder* builder, const Array2D<NativeT>& expected,
511     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
512   static_assert(std::is_same<NativeT, float>::value ||
513                     std::is_same<NativeT, double>::value ||
514                     std::is_same<NativeT, bfloat16>::value ||
515                     std::is_same<NativeT, half>::value ||
516                     std::is_same<NativeT, complex64>::value ||
517                     std::is_same<NativeT, complex128>::value,
518                 "Float or complex type required when specifying an ErrorSpec");
519   Literal expected_literal =
520       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
521   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
522                                                   arguments, error);
523 }
524 
525 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)526 void ClientLibraryTestBase::ComputeAndCompareR3(
527     XlaBuilder* builder, const Array3D<NativeT>& expected,
528     absl::Span<GlobalData* const> arguments) {
529   Literal expected_literal =
530       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
531   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
532                                                   arguments);
533 }
534 
535 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)536 void ClientLibraryTestBase::ComputeAndCompareR3(
537     XlaBuilder* builder, const Array3D<NativeT>& expected,
538     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
539   static_assert(std::is_same<NativeT, float>::value ||
540                     std::is_same<NativeT, double>::value ||
541                     std::is_same<NativeT, bfloat16>::value ||
542                     std::is_same<NativeT, half>::value ||
543                     std::is_same<NativeT, complex64>::value ||
544                     std::is_same<NativeT, complex128>::value,
545                 "Float or complex type required when specifying an ErrorSpec");
546   Literal expected_literal =
547       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
548   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
549                                                   arguments, error);
550 }
551 
552 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)553 void ClientLibraryTestBase::ComputeAndCompareR4(
554     XlaBuilder* builder, const Array4D<NativeT>& expected,
555     absl::Span<GlobalData* const> arguments) {
556   Literal expected_literal =
557       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
558   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
559                                                   arguments);
560 }
561 
562 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)563 void ClientLibraryTestBase::ComputeAndCompareR4(
564     XlaBuilder* builder, const Array4D<NativeT>& expected,
565     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
566   static_assert(std::is_same<NativeT, float>::value ||
567                     std::is_same<NativeT, double>::value ||
568                     std::is_same<NativeT, bfloat16>::value ||
569                     std::is_same<NativeT, half>::value ||
570                     std::is_same<NativeT, complex64>::value ||
571                     std::is_same<NativeT, complex128>::value,
572                 "Float or complex type required when specifying an ErrorSpec");
573   Literal expected_literal =
574       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
575   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
576                                                   arguments, error);
577 }
578 
579 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments)580 void ClientLibraryTestBase::ComputeAndCompare(
581     XlaBuilder* builder, const Array<NativeT>& expected,
582     absl::Span<GlobalData* const> arguments) {
583   Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
584   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
585                                                   arguments);
586 }
587 
588 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)589 void ClientLibraryTestBase::ComputeAndCompare(
590     XlaBuilder* builder, const Array<NativeT>& expected,
591     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
592   static_assert(std::is_same<NativeT, float>::value ||
593                     std::is_same<NativeT, double>::value ||
594                     std::is_same<NativeT, bfloat16>::value ||
595                     std::is_same<NativeT, half>::value ||
596                     std::is_same<NativeT, complex64>::value ||
597                     std::is_same<NativeT, complex128>::value,
598                 "Float or complex type required when specifying an ErrorSpec");
599   Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
600   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
601                                                   arguments, error);
602 }
603 
604 template <typename NativeT>
CreateR0Parameter(NativeT value,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)605 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
606     NativeT value, int64 parameter_number, const string& name,
607     XlaBuilder* builder, XlaOp* data_handle) {
608   Literal literal = LiteralUtil::CreateR0(value);
609   if (use_bfloat16_ && literal.shape().element_type() == F32) {
610     literal = LiteralUtil::ConvertF32ToBF16(literal);
611   }
612   std::unique_ptr<GlobalData> data =
613       client_->TransferToServer(literal).ConsumeValueOrDie();
614   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
615   return data;
616 }
617 
618 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)619 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
620     absl::Span<const NativeT> values, int64 parameter_number,
621     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
622   Literal literal = LiteralUtil::CreateR1(values);
623   if (use_bfloat16_ && literal.shape().element_type() == F32) {
624     literal = LiteralUtil::ConvertF32ToBF16(literal);
625   }
626   std::unique_ptr<GlobalData> data =
627       client_->TransferToServer(literal).ConsumeValueOrDie();
628   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
629   return data;
630 }
631 
632 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)633 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
634     const Array2D<NativeT>& array_2d, int64 parameter_number,
635     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
636   Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
637   if (use_bfloat16_ && literal.shape().element_type() == F32) {
638     literal = LiteralUtil::ConvertF32ToBF16(literal);
639   }
640   std::unique_ptr<GlobalData> data =
641       client_->TransferToServer(literal).ConsumeValueOrDie();
642   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
643   return data;
644 }
645 
646 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)647 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
648     const Array3D<NativeT>& array_3d, int64 parameter_number,
649     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
650   Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
651   if (use_bfloat16_ && literal.shape().element_type() == F32) {
652     literal = LiteralUtil::ConvertF32ToBF16(literal);
653   }
654   std::unique_ptr<GlobalData> data =
655       client_->TransferToServer(literal).ConsumeValueOrDie();
656   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
657   return data;
658 }
659 
660 template <typename NativeT>
CreateR4Parameter(const Array4D<NativeT> & array_4d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)661 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
662     const Array4D<NativeT>& array_4d, int64 parameter_number,
663     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
664   Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
665   if (use_bfloat16_ && literal.shape().element_type() == F32) {
666     literal = LiteralUtil::ConvertF32ToBF16(literal);
667   }
668   std::unique_ptr<GlobalData> data =
669       client_->TransferToServer(literal).ConsumeValueOrDie();
670   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
671   return data;
672 }
673 
674 template <typename NativeT>
CreateParameter(const Array<NativeT> & array,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)675 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
676     const Array<NativeT>& array, int64 parameter_number, const string& name,
677     XlaBuilder* builder, XlaOp* data_handle) {
678   Literal literal = LiteralUtil::CreateFromArray(array);
679   if (use_bfloat16_ && literal.shape().element_type() == F32) {
680     literal = LiteralUtil::ConvertF32ToBF16(literal);
681   }
682   std::unique_ptr<GlobalData> data =
683       client_->TransferToServer(literal).ConsumeValueOrDie();
684   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
685   return data;
686 }
687 
688 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32 seed)689 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
690     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
691   std::vector<NativeT> result(width);
692   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
693   for (int i = 0; i < width; ++i) {
694     result[i] = generator.get();
695   }
696   return result;
697 }
698 
699 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32 seed)700 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
701     const int rows, const int cols, NativeT min_value, NativeT max_value,
702     uint32 seed) {
703   auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
704   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
705   for (int y = 0; y < rows; ++y) {
706     for (int x = 0; x < cols; ++x) {
707       (*result)(y, x) = generator.get();
708     }
709   }
710   return result;
711 }
712 
713 }  // namespace xla
714 
715 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
716