• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18 
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/client_library.h"
31 #include "tensorflow/compiler/xla/client/global_data.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/compiler/xla/client/xla_computation.h"
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/literal_util.h"
36 #include "tensorflow/compiler/xla/statusor.h"
37 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/bitmap.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 // Sets the use_bfloat16 on a container of test cases according to the values in
48 // use_bfloat16_params. Generates one set of test cases for each values in
49 // use_bfloat16_params with that value. Returns the result.
50 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)51 std::vector<TestCase> ExpandUseBfloat16(
52     absl::Span<const bool> use_bfloat16_params,
53     absl::Span<const TestCase> specs) {
54   std::vector<TestCase> expanded;
55   for (bool use_bfloat16 : use_bfloat16_params) {
56     for (const auto& spec : specs) {
57       expanded.push_back(spec);
58       expanded.back().use_bfloat16 = use_bfloat16;
59     }
60   }
61   return expanded;
62 }
63 
64 // A client library test establishes an in-process XLA client connection.
65 class ClientLibraryTestBase : public ::testing::Test {
66  protected:
67   explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
68 
69   // Creates a new ClientLibraryTestBase with custom client options.
70   ClientLibraryTestBase(se::Platform* platform,
71                         const LocalClientOptions& client_options);
72 
73   // Returns the name of the test currently being run.
74   string TestName() const;
75 
SetFastMathDisabled(bool disabled)76   void SetFastMathDisabled(bool disabled) {
77     auto* opts = execution_options_.mutable_debug_options();
78     opts->set_xla_cpu_enable_fast_math(!disabled);
79     opts->set_xla_gpu_enable_fast_min_max(!disabled);
80   }
81 
SetSeed(uint64 seed)82   void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
83 
84   // Provides mutable access to the execution DebugOptions field; this lets
85   // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()86   DebugOptions* mutable_debug_options() {
87     return execution_options_.mutable_debug_options();
88   }
89 
90   // TODO(b/25566808): Add helper that populates a literal from a testdata file.
91 
92   // Convenience methods for building and running a computation with the member
93   // execution options. Modify execution_options_ in your test if you want to
94   // customize the options.
95   StatusOr<std::unique_ptr<GlobalData>> Execute(
96       XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
97 
98   StatusOr<Literal> ExecuteAndTransfer(
99       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
100       const Shape* shape_with_output_layout = nullptr);
101 
102   StatusOr<Literal> ExecuteAndTransfer(
103       const XlaComputation& computation,
104       absl::Span<GlobalData* const> arguments,
105       const Shape* shape_with_output_layout = nullptr);
106 
107   // This executes the computation via the reference client (which connects a
108   // interpreter backend). The result is used as the expected values of the
109   // computation.
110   StatusOr<Literal> ExecuteAndTransferReference(
111       const XlaComputation& computation,
112       absl::Span<GlobalData* const> arguments,
113       const Shape* shape_with_output_layout = nullptr);
114 
115   // Run a computation and return its value as a string. If an error
116   // occurs, then instead return the error as a string.
117   string ExecuteToString(XlaBuilder* builder,
118                          absl::Span<GlobalData* const> arguments);
119 
120   // Convenience methods for building and running a computation, transferring
121   // the result, and comparing it to the expected value(s). Methods are
122   // templated on the native host type which maps to specific XLA types (See
123   // XlaBuilder for details). For each rank, two forms are
124   // provided: one for floating point types with an ErrorSpec parameter, and one
125   // for integral types without the ErrorSpec parameter.
126   template <typename NativeT>
127   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
128                            absl::Span<GlobalData* const> arguments);
129   template <typename NativeT>
130   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
131                            absl::Span<GlobalData* const> arguments,
132                            ErrorSpec error);
133 
134   template <typename NativeT>
135   void ComputeAndCompareR1(XlaBuilder* builder,
136                            absl::Span<const NativeT> expected,
137                            absl::Span<GlobalData* const> arguments);
138   template <typename NativeT>
139   void ComputeAndCompareR1(XlaBuilder* builder,
140                            absl::Span<const NativeT> expected,
141                            absl::Span<GlobalData* const> arguments,
142                            ErrorSpec error);
143 
144   // As above, but uses a bitmap to hold the predicate vector to avoid
145   // deficiencies of vector<bool>.
146   void ComputeAndCompareR1(XlaBuilder* builder,
147                            const tensorflow::core::Bitmap& expected,
148                            absl::Span<GlobalData* const> arguments);
149 
150   template <typename NativeT>
151   void ComputeAndCompareR2(XlaBuilder* builder,
152                            const Array2D<NativeT>& expected,
153                            absl::Span<GlobalData* const> arguments);
154   template <typename NativeT>
155   void ComputeAndCompareR2(XlaBuilder* builder,
156                            const Array2D<NativeT>& expected,
157                            absl::Span<GlobalData* const> arguments,
158                            ErrorSpec error);
159 
160   template <typename NativeT>
161   void ComputeAndCompareR3(XlaBuilder* builder,
162                            const Array3D<NativeT>& expected,
163                            absl::Span<GlobalData* const> arguments);
164   template <typename NativeT>
165   void ComputeAndCompareR3(XlaBuilder* builder,
166                            const Array3D<NativeT>& expected,
167                            absl::Span<GlobalData* const> arguments,
168                            ErrorSpec error);
169 
170   template <typename NativeT>
171   void ComputeAndCompareR4(XlaBuilder* builder,
172                            const Array4D<NativeT>& expected,
173                            absl::Span<GlobalData* const> arguments);
174   template <typename NativeT>
175   void ComputeAndCompareR4(XlaBuilder* builder,
176                            const Array4D<NativeT>& expected,
177                            absl::Span<GlobalData* const> arguments,
178                            ErrorSpec error);
179 
180   // Build and run the computation and compare the result with the given
181   // literal. shape_with_layout indicates the result layout to request when
182   // calling Execute.
183   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
184                                 absl::Span<GlobalData* const> arguments,
185                                 const Shape* shape_with_layout = nullptr);
186   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
187                                 absl::Span<GlobalData* const> arguments,
188                                 ErrorSpec error,
189                                 const Shape* shape_with_layout = nullptr);
190 
191   // Build and run the computation and return the result as a literal.
192   // shape_with_layout indicates the result layout to request when calling
193   // Execute.
194   StatusOr<Literal> ComputeAndTransfer(
195       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
196       const Shape* shape_with_layout = nullptr);
197 
198   // ComputeAndCompare variant which returns an error status.
199   Status ComputeAndCompareLiteralWithStatus(
200       XlaBuilder* builder, const Literal& expected,
201       absl::Span<GlobalData* const> arguments,
202       const Shape* shape_with_layout = nullptr);
203   Status ComputeAndCompareLiteralWithStatus(
204       XlaBuilder* builder, const Literal& expected,
205       absl::Span<GlobalData* const> arguments, ErrorSpec error,
206       const Shape* shape_with_layout = nullptr);
207 
208   // Compare the result of the computation to a strings. In XLA strings are
209   // represented using rank-1 U8 shapes.
210   void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
211                              absl::Span<GlobalData* const> arguments);
212 
213   // Convenience method for running a built computation, transferring the
214   // result, and comparing it to the expected tuple literal.
215   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
216                               absl::Span<GlobalData* const> arguments);
217   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218                               absl::Span<GlobalData* const> arguments,
219                               ErrorSpec error);
220 
221   // Convenience method for running a built computation and comparing the result
222   // with the reference result.
223   void ComputeAndCompare(XlaBuilder* builder,
224                          absl::Span<const Literal> arguments);
225   void ComputeAndCompare(XlaBuilder* builder,
226                          absl::Span<const Literal> arguments, ErrorSpec error);
227 
228   // Create scalar operations for use in reductions.
229   XlaComputation CreateScalarRelu();
230   XlaComputation CreateScalarMax();
231   XlaComputation CreateScalarReluSensitivity();
232 
233   // Special case convenience functions for creating filled arrays.
234 
235   // Creates an array of pseudorandom values lying between the given minimum and
236   // maximum values.
237   template <typename NativeT>
238   std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
239                                             NativeT max_value, uint32 seed);
240   template <typename NativeT>
241   std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
242                                                          const int cols,
243                                                          NativeT min_value,
244                                                          NativeT max_value,
245                                                          uint32 seed);
246 
247   // Creates a (rows x cols) array filled in the following form:
248   //
249   //  [      0              1 ...                   cols-1]
250   //  [  1,000          1,001 ...          1000.0 + cols-1]
251   //  [    ...            ... ...                      ...]
252   //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
253   //
254   // If provided, offset is added uniformly to every element (e.g. an offset of
255   // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
256   std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
257                                                         const int cols,
258                                                         float offset = 0.0);
259 
260   // Creates a (rows x cols) array as above, padded out to
261   // (rows_padded x cols_padded) with zeroes.  Requires rows_padded >= rows
262   // and cols_padded > cols.
263   std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
264       const int rows, const int cols, const int rows_padded,
265       const int cols_padded);
266 
267   // Creates a parameter instruction, transfers the literal for the parameter to
268   // server, then stores into "data_handle" the global handle for that
269   // parameter. When the use_bfloat16 flag is set but the literal has F32
270   // elements, the literal will be converted to BF16 before being transferred.
271   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
272       int64 parameter_number, const Literal& literal, const string& name,
273       XlaBuilder* builder, XlaOp* data_handle);
274 
275   // As above, but the caller can specify the device that the literal is
276   // transferred to. If device_handle is nullptr, the literal will be
277   // transferred to the default device.
278   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
279       int64 parameter_number, const Literal& literal, const string& name,
280       const DeviceHandle* device_handle, XlaBuilder* builder,
281       XlaOp* data_handle);
282 
283   // Creates a parameter instruction and sets the value that will be passed to
284   // the computation as specified. This function must be used for all parameters
285   // or none and no parameters must be passed when invoking the computation if
286   // using this mechanism. If using this mechanism, then each parameter must be
287   // set exactly once. The first added parameter gets index 0, then 1 and so on.
288   XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
289 
290   template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)291   XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
292     return AddParam(LiteralUtil::CreateFromArray(argument), builder);
293   }
294 
295   // Creates a constant instruction with the given literal. When the
296   // use_bfloat16 flag is set but the literal has F32 elements, the elements
297   // will be converted to BF16s.
298   XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
299 
300   // Creates a constant instruction with the given array. When the use_bfloat16
301   // flag is set but the array has float elements, the elements will be
302   // converted to bfloat16s.
303 
304   template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)305   XlaOp CreateConstantFromArray(const Array<NativeT>& array,
306                                 XlaBuilder* builder) {
307     return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
308                                      builder);
309   }
310 
311   // Same as CreateConstantFromArray, but for scalars.
312   template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)313   XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
314     return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
315                                      builder);
316   }
317 
318   // Creates a parameter instruction that wraps a given value and then stores
319   // into "data_handle" the global handle for that parameter.
320   //
321   // "parameter_number" is the parameter number.
322   // "name" is the name of the parameter instruction.
323   //
324   // When the use_bfloat16 flag is set but NativeT is float, the data will be
325   // converted to bfloat16.
326   template <typename NativeT>
327   std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
328                                                 int64 parameter_number,
329                                                 const string& name,
330                                                 XlaBuilder* builder,
331                                                 XlaOp* data_handle);
332 
333   // Creates a parameter instruction that wraps the given values and then stores
334   // into "data_handle" the global handle for that parameter.
335   //
336   // "parameter_number" is the parameter number.
337   // "name" is the name of the parameter instruction.
338   //
339   // When the use_bfloat16 flag is set but NativeT is float, the data will be
340   // converted to bfloat16.
341   template <typename NativeT>
342   std::unique_ptr<GlobalData> CreateR1Parameter(
343       absl::Span<const NativeT> values, int64 parameter_number,
344       const string& name, XlaBuilder* builder, XlaOp* data_handle);
345 
346   // Creates a parameter instruction that wraps the given constant array
347   // "array_2d" and then stores to "data_handle" the global handle for that
348   // parameter.
349   //
350   // "parameter_number" is the parameter number.
351   // "name" is the name of the parameter instruction.
352   //
353   // When the use_bfloat16 flag is set but NativeT is float, the data will be
354   // converted to bfloat16.
355   template <typename NativeT>
356   std::unique_ptr<GlobalData> CreateR2Parameter(
357       const Array2D<NativeT>& array_2d, int64 parameter_number,
358       const string& name, XlaBuilder* builder, XlaOp* data_handle);
359 
360   // Creates a parameter instruction that wraps the given constant array
361   // "array_3d" and then stores to "data_handle" the global handle for that
362   // parameter.
363   //
364   // "parameter_number" is the parameter number.
365   // "name" is the name of the parameter instruction.
366   //
367   // When the use_bfloat16 flag is set but NativeT is float, the data will be
368   // converted to bfloat16.
369   template <typename NativeT>
370   std::unique_ptr<GlobalData> CreateR3Parameter(
371       const Array3D<NativeT>& array_3d, int64 parameter_number,
372       const string& name, XlaBuilder* builder, XlaOp* data_handle);
373 
374   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
375   // tests with all float-type input/output converted to bfloat16.
use_bfloat16()376   bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)377   void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
378 
379   // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()380   PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
381 
382   // Executes the computation and calculates the expected reference value using
383   // the reference client. Returns two literals in the order of (expected,
384   // actual).
385   StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
386       XlaBuilder* builder, absl::Span<const Literal> arguments);
387 
388   LocalClient* client_;
389   LocalClient* ref_client_;  // To compute reference result.
390   ExecutionOptions execution_options_;
391 
392  private:
393   Status ComputeAndCompareLiteralWithAllOutputLayouts(
394       const xla::XlaComputation& computation, const Literal& expected,
395       absl::Span<GlobalData* const> arguments,
396       const std::function<void(const Literal& actual,
397                                const string& error_message)>& verify_output);
398   Status ComputeAndCompareLiteralWithAllInputLayouts(
399       const xla::XlaComputation& computation, const Literal& expected,
400       absl::Span<GlobalData* const> arguments,
401       const std::function<void(const Literal& actual,
402                                const string& error_message)>& verify_output,
403       const Shape* output_with_layout = nullptr);
404 
405   // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true.
406   Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
407   Shape MaybeConvertShapeToBfloat16(const Shape& shape);
408 
409   // Whether to run tests with all float-type input/output converted to
410   // bfloat16.
411   bool use_bfloat16_ = false;
412 
413   // Arguments to be passed to the computation when it runs.
414   std::vector<Literal> arguments_;
415 };
416 
417 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)418 void ClientLibraryTestBase::ComputeAndCompareR0(
419     XlaBuilder* builder, NativeT expected,
420     absl::Span<GlobalData* const> arguments) {
421   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
422   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
423                                                   arguments);
424 }
425 
426 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)427 void ClientLibraryTestBase::ComputeAndCompareR0(
428     XlaBuilder* builder, NativeT expected,
429     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
430   static_assert(std::is_same<NativeT, float>::value ||
431                     std::is_same<NativeT, double>::value ||
432                     std::is_same<NativeT, bfloat16>::value ||
433                     std::is_same<NativeT, half>::value ||
434                     std::is_same<NativeT, complex64>::value ||
435                     std::is_same<NativeT, complex128>::value,
436                 "Float or complex type required when specifying an ErrorSpec");
437   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
438   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
439                                                   arguments, error);
440 }
441 
442 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)443 void ClientLibraryTestBase::ComputeAndCompareR1(
444     XlaBuilder* builder, absl::Span<const NativeT> expected,
445     absl::Span<GlobalData* const> arguments) {
446   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
447   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
448                                                   arguments);
449 }
450 
451 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)452 void ClientLibraryTestBase::ComputeAndCompareR1(
453     XlaBuilder* builder, absl::Span<const NativeT> expected,
454     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
455   static_assert(std::is_same<NativeT, float>::value ||
456                     std::is_same<NativeT, double>::value ||
457                     std::is_same<NativeT, bfloat16>::value ||
458                     std::is_same<NativeT, half>::value ||
459                     std::is_same<NativeT, complex64>::value ||
460                     std::is_same<NativeT, complex128>::value,
461                 "Float or complex type required when specifying an ErrorSpec");
462   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
463   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
464                                                   arguments, error);
465 }
466 
467 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)468 void ClientLibraryTestBase::ComputeAndCompareR2(
469     XlaBuilder* builder, const Array2D<NativeT>& expected,
470     absl::Span<GlobalData* const> arguments) {
471   Literal expected_literal =
472       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
473   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
474                                                   arguments);
475 }
476 
477 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)478 void ClientLibraryTestBase::ComputeAndCompareR2(
479     XlaBuilder* builder, const Array2D<NativeT>& expected,
480     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
481   static_assert(std::is_same<NativeT, float>::value ||
482                     std::is_same<NativeT, double>::value ||
483                     std::is_same<NativeT, bfloat16>::value ||
484                     std::is_same<NativeT, half>::value ||
485                     std::is_same<NativeT, complex64>::value ||
486                     std::is_same<NativeT, complex128>::value,
487                 "Float or complex type required when specifying an ErrorSpec");
488   Literal expected_literal =
489       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
490   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
491                                                   arguments, error);
492 }
493 
494 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)495 void ClientLibraryTestBase::ComputeAndCompareR3(
496     XlaBuilder* builder, const Array3D<NativeT>& expected,
497     absl::Span<GlobalData* const> arguments) {
498   Literal expected_literal =
499       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
500   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
501                                                   arguments);
502 }
503 
504 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)505 void ClientLibraryTestBase::ComputeAndCompareR3(
506     XlaBuilder* builder, const Array3D<NativeT>& expected,
507     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
508   static_assert(std::is_same<NativeT, float>::value ||
509                     std::is_same<NativeT, double>::value ||
510                     std::is_same<NativeT, bfloat16>::value ||
511                     std::is_same<NativeT, half>::value ||
512                     std::is_same<NativeT, complex64>::value ||
513                     std::is_same<NativeT, complex128>::value,
514                 "Float or complex type required when specifying an ErrorSpec");
515   Literal expected_literal =
516       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
517   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
518                                                   arguments, error);
519 }
520 
521 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)522 void ClientLibraryTestBase::ComputeAndCompareR4(
523     XlaBuilder* builder, const Array4D<NativeT>& expected,
524     absl::Span<GlobalData* const> arguments) {
525   Literal expected_literal =
526       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
527   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
528                                                   arguments);
529 }
530 
531 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)532 void ClientLibraryTestBase::ComputeAndCompareR4(
533     XlaBuilder* builder, const Array4D<NativeT>& expected,
534     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
535   static_assert(std::is_same<NativeT, float>::value ||
536                     std::is_same<NativeT, double>::value ||
537                     std::is_same<NativeT, bfloat16>::value ||
538                     std::is_same<NativeT, half>::value ||
539                     std::is_same<NativeT, complex64>::value ||
540                     std::is_same<NativeT, complex128>::value,
541                 "Float or complex type required when specifying an ErrorSpec");
542   Literal expected_literal =
543       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
544   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
545                                                   arguments, error);
546 }
547 
548 template <typename NativeT>
CreateR0Parameter(NativeT value,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)549 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
550     NativeT value, int64 parameter_number, const string& name,
551     XlaBuilder* builder, XlaOp* data_handle) {
552   Literal literal = LiteralUtil::CreateR0(value);
553   if (use_bfloat16_ && literal.shape().element_type() == F32) {
554     literal = LiteralUtil::ConvertF32ToBF16(literal);
555   }
556   std::unique_ptr<GlobalData> data =
557       client_->TransferToServer(literal).ConsumeValueOrDie();
558   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
559   return data;
560 }
561 
562 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)563 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
564     absl::Span<const NativeT> values, int64 parameter_number,
565     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
566   Literal literal = LiteralUtil::CreateR1(values);
567   if (use_bfloat16_ && literal.shape().element_type() == F32) {
568     literal = LiteralUtil::ConvertF32ToBF16(literal);
569   }
570   std::unique_ptr<GlobalData> data =
571       client_->TransferToServer(literal).ConsumeValueOrDie();
572   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
573   return data;
574 }
575 
576 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)577 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
578     const Array2D<NativeT>& array_2d, int64 parameter_number,
579     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
580   Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
581   if (use_bfloat16_ && literal.shape().element_type() == F32) {
582     literal = LiteralUtil::ConvertF32ToBF16(literal);
583   }
584   std::unique_ptr<GlobalData> data =
585       client_->TransferToServer(literal).ConsumeValueOrDie();
586   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
587   return data;
588 }
589 
590 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)591 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
592     const Array3D<NativeT>& array_3d, int64 parameter_number,
593     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
594   Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
595   if (use_bfloat16_ && literal.shape().element_type() == F32) {
596     literal = LiteralUtil::ConvertF32ToBF16(literal);
597   }
598   std::unique_ptr<GlobalData> data =
599       client_->TransferToServer(literal).ConsumeValueOrDie();
600   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
601   return data;
602 }
603 
604 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32 seed)605 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
606     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
607   std::vector<NativeT> result(width);
608   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
609   for (int i = 0; i < width; ++i) {
610     result[i] = generator.get();
611   }
612   return result;
613 }
614 
615 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32 seed)616 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
617     const int rows, const int cols, NativeT min_value, NativeT max_value,
618     uint32 seed) {
619   auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
620   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
621   for (int y = 0; y < rows; ++y) {
622     for (int x = 0; x < cols; ++x) {
623       (*result)(y, x) = generator.get();
624     }
625   }
626   return result;
627 }
628 
629 }  // namespace xla
630 
631 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
632