1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
45
46 namespace xla {
47
48 // Sets the use_bfloat16 on a container of test cases according to the values in
49 // use_bfloat16_params. Generates one set of test cases for each values in
50 // use_bfloat16_params with that value. Returns the result.
51 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)52 std::vector<TestCase> ExpandUseBfloat16(
53 absl::Span<const bool> use_bfloat16_params,
54 absl::Span<const TestCase> specs) {
55 std::vector<TestCase> expanded;
56 for (bool use_bfloat16 : use_bfloat16_params) {
57 for (const auto& spec : specs) {
58 expanded.push_back(spec);
59 expanded.back().use_bfloat16 = use_bfloat16;
60 }
61 }
62 return expanded;
63 }
64
65 // A client library test establishes an in-process XLA client connection.
66 class ClientLibraryTestBase : public ManifestCheckingTest {
67 protected:
68 explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
69
70 // Creates a new ClientLibraryTestBase with custom client options.
71 ClientLibraryTestBase(se::Platform* platform,
72 const LocalClientOptions& client_options);
73
74 // Returns the name of the test currently being run.
75 string TestName() const;
76
SetFastMathDisabled(bool disabled)77 void SetFastMathDisabled(bool disabled) {
78 auto* opts = execution_options_.mutable_debug_options();
79 opts->set_xla_cpu_enable_fast_math(!disabled);
80 opts->set_xla_cpu_enable_fast_min_max(!disabled);
81 opts->set_xla_gpu_enable_fast_min_max(!disabled);
82 }
83
SetSeed(uint64 seed)84 void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
85
86 // Provides mutable access to the execution DebugOptions field; this lets
87 // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()88 DebugOptions* mutable_debug_options() {
89 return execution_options_.mutable_debug_options();
90 }
91
92 // TODO(b/25566808): Add helper that populates a literal from a testdata file.
93
94 // Convenience methods for building and running a computation with the member
95 // execution options. Modify execution_options_ in your test if you want to
96 // customize the options.
97 StatusOr<std::unique_ptr<GlobalData>> Execute(
98 XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
99
100 StatusOr<Literal> ExecuteAndTransfer(
101 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
102 const Shape* shape_with_output_layout = nullptr);
103
104 StatusOr<Literal> ExecuteAndTransfer(
105 const XlaComputation& computation,
106 absl::Span<GlobalData* const> arguments,
107 const Shape* shape_with_output_layout = nullptr);
108
109 // This executes the computation via the reference client (which connects a
110 // interpreter backend). The result is used as the expected value of the
111 // computation.
112 StatusOr<Literal> ExecuteAndTransferReference(
113 const XlaComputation& computation,
114 absl::Span<GlobalData* const> arguments,
115 const Shape* shape_with_output_layout = nullptr);
116
117 // Run a computation and return its value as a string. If an error
118 // occurs, then instead return the error as a string.
119 string ExecuteToString(XlaBuilder* builder,
120 absl::Span<GlobalData* const> arguments);
121
122 // Convenience methods for building and running a computation, transferring
123 // the result, and comparing it to the expected value(s). Methods are
124 // templated on the native host type which maps to specific XLA types (See
125 // XlaBuilder for details). For each rank, two forms are
126 // provided: one for floating point types with an ErrorSpec parameter, and one
127 // for integral types without the ErrorSpec parameter.
128 template <typename NativeT>
129 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
130 absl::Span<GlobalData* const> arguments);
131 template <typename NativeT>
132 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
133 absl::Span<GlobalData* const> arguments,
134 ErrorSpec error);
135
136 template <typename NativeT>
137 void ComputeAndCompareR1(XlaBuilder* builder,
138 absl::Span<const NativeT> expected,
139 absl::Span<GlobalData* const> arguments);
140 template <typename NativeT>
141 void ComputeAndCompareR1(XlaBuilder* builder,
142 absl::Span<const NativeT> expected,
143 absl::Span<GlobalData* const> arguments,
144 ErrorSpec error);
145
146 // As above, but uses a bitmap to hold the predicate vector to avoid
147 // deficiencies of vector<bool>.
148 void ComputeAndCompareR1(XlaBuilder* builder,
149 const tensorflow::core::Bitmap& expected,
150 absl::Span<GlobalData* const> arguments);
151
152 template <typename NativeT>
153 void ComputeAndCompareR2(XlaBuilder* builder,
154 const Array2D<NativeT>& expected,
155 absl::Span<GlobalData* const> arguments);
156 template <typename NativeT>
157 void ComputeAndCompareR2(XlaBuilder* builder,
158 const Array2D<NativeT>& expected,
159 absl::Span<GlobalData* const> arguments,
160 ErrorSpec error);
161
162 template <typename NativeT>
163 void ComputeAndCompareR3(XlaBuilder* builder,
164 const Array3D<NativeT>& expected,
165 absl::Span<GlobalData* const> arguments);
166 template <typename NativeT>
167 void ComputeAndCompareR3(XlaBuilder* builder,
168 const Array3D<NativeT>& expected,
169 absl::Span<GlobalData* const> arguments,
170 ErrorSpec error);
171
172 template <typename NativeT>
173 void ComputeAndCompareR4(XlaBuilder* builder,
174 const Array4D<NativeT>& expected,
175 absl::Span<GlobalData* const> arguments);
176 template <typename NativeT>
177 void ComputeAndCompareR4(XlaBuilder* builder,
178 const Array4D<NativeT>& expected,
179 absl::Span<GlobalData* const> arguments,
180 ErrorSpec error);
181
182 // Build and run the computation and compare the result with the given
183 // literal. shape_with_layout indicates the result layout to request when
184 // calling Execute.
185 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
186 absl::Span<GlobalData* const> arguments,
187 const Shape* shape_with_layout = nullptr);
188 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
189 absl::Span<GlobalData* const> arguments,
190 ErrorSpec error,
191 const Shape* shape_with_layout = nullptr);
192
193 // Build and run the computation and return the result as a literal.
194 // shape_with_layout indicates the result layout to request when calling
195 // Execute.
196 StatusOr<Literal> ComputeAndTransfer(
197 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
198 const Shape* shape_with_layout = nullptr);
199
200 // ComputeAndCompare variant which returns an error status.
201 Status ComputeAndCompareLiteralWithStatus(
202 XlaBuilder* builder, const Literal& expected,
203 absl::Span<GlobalData* const> arguments,
204 const Shape* shape_with_layout = nullptr);
205 Status ComputeAndCompareLiteralWithStatus(
206 XlaBuilder* builder, const Literal& expected,
207 absl::Span<GlobalData* const> arguments, ErrorSpec error,
208 const Shape* shape_with_layout = nullptr);
209
210 // Compare the result of the computation to a strings. In XLA strings are
211 // represented using rank-1 U8 shapes.
212 void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
213 absl::Span<GlobalData* const> arguments);
214
215 // Convenience method for running a built computation, transferring the
216 // result, and comparing it to the expected tuple literal.
217 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218 absl::Span<GlobalData* const> arguments);
219 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
220 absl::Span<GlobalData* const> arguments,
221 ErrorSpec error);
222
223 // Convenience method for running a built computation and comparing the result
224 // with the reference result.
225 void ComputeAndCompare(XlaBuilder* builder,
226 absl::Span<const Literal> arguments);
227 void ComputeAndCompare(XlaBuilder* builder,
228 absl::Span<const Literal> arguments, ErrorSpec error);
229 template <typename NativeT>
230 void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
231 absl::Span<GlobalData* const> arguments);
232 template <typename NativeT>
233 void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
234 absl::Span<GlobalData* const> arguments,
235 ErrorSpec error);
236 // Create scalar operations for use in reductions.
237 XlaComputation CreateScalarRelu();
238 XlaComputation CreateScalarMax();
239 XlaComputation CreateScalarReluSensitivity();
240
241 // Special case convenience functions for creating filled arrays.
242
243 // Creates an array of pseudorandom values lying between the given minimum and
244 // maximum values.
245 template <typename NativeT>
246 std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
247 NativeT max_value, uint32 seed);
248 template <typename NativeT>
249 std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
250 const int cols,
251 NativeT min_value,
252 NativeT max_value,
253 uint32 seed);
254
255 // Creates a (rows x cols) array filled in the following form:
256 //
257 // [ 0 1 ... cols-1]
258 // [ 1,000 1,001 ... 1000.0 + cols-1]
259 // [ ... ... ... ...]
260 // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1]
261 //
262 // If provided, offset is added uniformly to every element (e.g. an offset of
263 // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
264 std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
265 const int cols,
266 float offset = 0.0);
267
268 // Creates a (rows x cols) array as above, padded out to
269 // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
270 // and cols_padded > cols.
271 std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
272 const int rows, const int cols, const int rows_padded,
273 const int cols_padded);
274
275 // Creates a parameter instruction, transfers the literal for the parameter to
276 // server, then stores into "data_handle" the global handle for that
277 // parameter. When the use_bfloat16 flag is set but the literal has F32
278 // elements, the literal will be converted to BF16 before being transferred.
279 StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
280 int64 parameter_number, const Literal& literal, const string& name,
281 XlaBuilder* builder, XlaOp* data_handle);
282
283 // As above, but the caller can specify the device that the literal is
284 // transferred to. If device_handle is nullptr, the literal will be
285 // transferred to the default device.
286 StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
287 int64 parameter_number, const Literal& literal, const string& name,
288 const DeviceHandle* device_handle, XlaBuilder* builder,
289 XlaOp* data_handle);
290
291 // Creates a parameter instruction and sets the value that will be passed to
292 // the computation as specified. This function must be used for all parameters
293 // or none and no parameters must be passed when invoking the computation if
294 // using this mechanism. If using this mechanism, then each parameter must be
295 // set exactly once. The first added parameter gets index 0, then 1 and so on.
296 XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
297
298 template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)299 XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
300 return AddParam(LiteralUtil::CreateFromArray(argument), builder);
301 }
302
303 // Creates a constant instruction with the given literal. When the
304 // use_bfloat16 flag is set but the literal has F32 elements, the elements
305 // will be converted to BF16s.
306 XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
307
308 // Creates a constant instruction with the given array. When the use_bfloat16
309 // flag is set but the array has float elements, the elements will be
310 // converted to bfloat16s.
311
312 template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)313 XlaOp CreateConstantFromArray(const Array<NativeT>& array,
314 XlaBuilder* builder) {
315 return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
316 builder);
317 }
318
319 // Same as CreateConstantFromArray, but for scalars.
320 template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)321 XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
322 return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
323 builder);
324 }
325
326 // Creates a parameter instruction that wraps a given value and then stores
327 // into "data_handle" the global handle for that parameter.
328 //
329 // "parameter_number" is the parameter number.
330 // "name" is the name of the parameter instruction.
331 //
332 // When the use_bfloat16 flag is set but NativeT is float, the data will be
333 // converted to bfloat16.
334 template <typename NativeT>
335 std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
336 int64 parameter_number,
337 const string& name,
338 XlaBuilder* builder,
339 XlaOp* data_handle);
340
341 // Creates a parameter instruction that wraps the given values and then stores
342 // into "data_handle" the global handle for that parameter.
343 //
344 // "parameter_number" is the parameter number.
345 // "name" is the name of the parameter instruction.
346 //
347 // When the use_bfloat16 flag is set but NativeT is float, the data will be
348 // converted to bfloat16.
349 template <typename NativeT>
350 std::unique_ptr<GlobalData> CreateR1Parameter(
351 absl::Span<const NativeT> values, int64 parameter_number,
352 const string& name, XlaBuilder* builder, XlaOp* data_handle);
353
354 // Creates a parameter instruction that wraps the given constant array
355 // "array_2d" and then stores it to the global handle for that parameter
356 // "data_handle".
357 //
358 // "parameter_number" is the parameter number.
359 // "name" is the name of the parameter instruction.
360 //
361 // When the use_bfloat16 flag is set but NativeT is float, the data will be
362 // converted to bfloat16.
363 template <typename NativeT>
364 std::unique_ptr<GlobalData> CreateR2Parameter(
365 const Array2D<NativeT>& array_2d, int64 parameter_number,
366 const string& name, XlaBuilder* builder, XlaOp* data_handle);
367
368 // Creates a parameter instruction that wraps the given constant array
369 // "array_3d" and then stores it to the global handle for that parameter
370 // "data_handle".
371 //
372 // "parameter_number" is the parameter number.
373 // "name" is the name of the parameter instruction.
374 //
375 // When the use_bfloat16 flag is set but NativeT is float, the data will be
376 // converted to bfloat16.
377 template <typename NativeT>
378 std::unique_ptr<GlobalData> CreateR3Parameter(
379 const Array3D<NativeT>& array_3d, int64 parameter_number,
380 const string& name, XlaBuilder* builder, XlaOp* data_handle);
381
382 // Creates a parameter instruction that wraps the given constant array
383 // "array_4d" and then stores it to the global handle for that parameter
384 // "data_handle".
385 //
386 // "parameter_number" is the parameter number.
387 // "name" is the name of the parameter instruction.
388 //
389 // When the use_bfloat16 flag is set but NativeT is float, the data will be
390 // converted to bfloat16.
391 template <typename NativeT>
392 std::unique_ptr<GlobalData> CreateR4Parameter(
393 const Array4D<NativeT>& array_4d, int64 parameter_number,
394 const string& name, XlaBuilder* builder, XlaOp* data_handle);
395
396 template <typename NativeT>
397 std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
398 int64 parameter_number,
399 const string& name,
400 XlaBuilder* builder,
401 XlaOp* data_handle);
402
403 // Getter and setter for the use_bfloat16 flag, which indicates whether to run
404 // tests with all float-type input/output converted to bfloat16.
use_bfloat16()405 bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)406 void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
407
408 // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()409 PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
410
411 // Executes the computation and calculates the expected reference value using
412 // the reference client. Returns two literals in the order of (expected,
413 // actual).
414 StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
415 XlaBuilder* builder, absl::Span<const Literal> arguments);
416
417 // Converts an f32 literal to bf16 if use_bfloat16_ is true.
418 Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
419
420 LocalClient* client_;
421 LocalClient* ref_client_; // To compute reference result.
422 ExecutionOptions execution_options_;
423
424 private:
425 Status ComputeAndCompareLiteralWithAllOutputLayouts(
426 const xla::XlaComputation& computation, const Literal& expected,
427 absl::Span<GlobalData* const> arguments,
428 const std::function<void(const Literal& actual,
429 const string& error_message)>& verify_output);
430 Status ComputeAndCompareLiteralWithAllInputLayouts(
431 const xla::XlaComputation& computation, const Literal& expected,
432 absl::Span<GlobalData* const> arguments,
433 const std::function<void(const Literal& actual,
434 const string& error_message)>& verify_output,
435 const Shape* output_with_layout = nullptr);
436
437 // Converts an f32 shape to bf16 if use_bfloat16_ is true.
438 Shape MaybeConvertShapeToBfloat16(const Shape& shape);
439
440 // Whether to run tests with all float-type input/output converted to
441 // bfloat16.
442 bool use_bfloat16_ = false;
443
444 // Arguments to be passed to the computation when it runs.
445 std::vector<Literal> arguments_;
446 };
447
448 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)449 void ClientLibraryTestBase::ComputeAndCompareR0(
450 XlaBuilder* builder, NativeT expected,
451 absl::Span<GlobalData* const> arguments) {
452 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
453 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
454 arguments);
455 }
456
457 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)458 void ClientLibraryTestBase::ComputeAndCompareR0(
459 XlaBuilder* builder, NativeT expected,
460 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
461 static_assert(std::is_same<NativeT, float>::value ||
462 std::is_same<NativeT, double>::value ||
463 std::is_same<NativeT, bfloat16>::value ||
464 std::is_same<NativeT, half>::value ||
465 std::is_same<NativeT, complex64>::value ||
466 std::is_same<NativeT, complex128>::value,
467 "Float or complex type required when specifying an ErrorSpec");
468 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
469 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
470 arguments, error);
471 }
472
473 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)474 void ClientLibraryTestBase::ComputeAndCompareR1(
475 XlaBuilder* builder, absl::Span<const NativeT> expected,
476 absl::Span<GlobalData* const> arguments) {
477 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
478 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
479 arguments);
480 }
481
482 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)483 void ClientLibraryTestBase::ComputeAndCompareR1(
484 XlaBuilder* builder, absl::Span<const NativeT> expected,
485 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
486 static_assert(std::is_same<NativeT, float>::value ||
487 std::is_same<NativeT, double>::value ||
488 std::is_same<NativeT, bfloat16>::value ||
489 std::is_same<NativeT, half>::value ||
490 std::is_same<NativeT, complex64>::value ||
491 std::is_same<NativeT, complex128>::value,
492 "Float or complex type required when specifying an ErrorSpec");
493 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
494 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
495 arguments, error);
496 }
497
498 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)499 void ClientLibraryTestBase::ComputeAndCompareR2(
500 XlaBuilder* builder, const Array2D<NativeT>& expected,
501 absl::Span<GlobalData* const> arguments) {
502 Literal expected_literal =
503 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
504 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
505 arguments);
506 }
507
508 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)509 void ClientLibraryTestBase::ComputeAndCompareR2(
510 XlaBuilder* builder, const Array2D<NativeT>& expected,
511 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
512 static_assert(std::is_same<NativeT, float>::value ||
513 std::is_same<NativeT, double>::value ||
514 std::is_same<NativeT, bfloat16>::value ||
515 std::is_same<NativeT, half>::value ||
516 std::is_same<NativeT, complex64>::value ||
517 std::is_same<NativeT, complex128>::value,
518 "Float or complex type required when specifying an ErrorSpec");
519 Literal expected_literal =
520 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
521 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
522 arguments, error);
523 }
524
525 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)526 void ClientLibraryTestBase::ComputeAndCompareR3(
527 XlaBuilder* builder, const Array3D<NativeT>& expected,
528 absl::Span<GlobalData* const> arguments) {
529 Literal expected_literal =
530 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
531 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
532 arguments);
533 }
534
535 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)536 void ClientLibraryTestBase::ComputeAndCompareR3(
537 XlaBuilder* builder, const Array3D<NativeT>& expected,
538 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
539 static_assert(std::is_same<NativeT, float>::value ||
540 std::is_same<NativeT, double>::value ||
541 std::is_same<NativeT, bfloat16>::value ||
542 std::is_same<NativeT, half>::value ||
543 std::is_same<NativeT, complex64>::value ||
544 std::is_same<NativeT, complex128>::value,
545 "Float or complex type required when specifying an ErrorSpec");
546 Literal expected_literal =
547 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
548 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
549 arguments, error);
550 }
551
552 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)553 void ClientLibraryTestBase::ComputeAndCompareR4(
554 XlaBuilder* builder, const Array4D<NativeT>& expected,
555 absl::Span<GlobalData* const> arguments) {
556 Literal expected_literal =
557 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
558 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
559 arguments);
560 }
561
562 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)563 void ClientLibraryTestBase::ComputeAndCompareR4(
564 XlaBuilder* builder, const Array4D<NativeT>& expected,
565 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
566 static_assert(std::is_same<NativeT, float>::value ||
567 std::is_same<NativeT, double>::value ||
568 std::is_same<NativeT, bfloat16>::value ||
569 std::is_same<NativeT, half>::value ||
570 std::is_same<NativeT, complex64>::value ||
571 std::is_same<NativeT, complex128>::value,
572 "Float or complex type required when specifying an ErrorSpec");
573 Literal expected_literal =
574 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
575 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
576 arguments, error);
577 }
578
579 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments)580 void ClientLibraryTestBase::ComputeAndCompare(
581 XlaBuilder* builder, const Array<NativeT>& expected,
582 absl::Span<GlobalData* const> arguments) {
583 Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
584 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
585 arguments);
586 }
587
588 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)589 void ClientLibraryTestBase::ComputeAndCompare(
590 XlaBuilder* builder, const Array<NativeT>& expected,
591 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
592 static_assert(std::is_same<NativeT, float>::value ||
593 std::is_same<NativeT, double>::value ||
594 std::is_same<NativeT, bfloat16>::value ||
595 std::is_same<NativeT, half>::value ||
596 std::is_same<NativeT, complex64>::value ||
597 std::is_same<NativeT, complex128>::value,
598 "Float or complex type required when specifying an ErrorSpec");
599 Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
600 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
601 arguments, error);
602 }
603
604 template <typename NativeT>
CreateR0Parameter(NativeT value,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)605 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
606 NativeT value, int64 parameter_number, const string& name,
607 XlaBuilder* builder, XlaOp* data_handle) {
608 Literal literal = LiteralUtil::CreateR0(value);
609 if (use_bfloat16_ && literal.shape().element_type() == F32) {
610 literal = LiteralUtil::ConvertF32ToBF16(literal);
611 }
612 std::unique_ptr<GlobalData> data =
613 client_->TransferToServer(literal).ConsumeValueOrDie();
614 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
615 return data;
616 }
617
618 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)619 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
620 absl::Span<const NativeT> values, int64 parameter_number,
621 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
622 Literal literal = LiteralUtil::CreateR1(values);
623 if (use_bfloat16_ && literal.shape().element_type() == F32) {
624 literal = LiteralUtil::ConvertF32ToBF16(literal);
625 }
626 std::unique_ptr<GlobalData> data =
627 client_->TransferToServer(literal).ConsumeValueOrDie();
628 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
629 return data;
630 }
631
632 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)633 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
634 const Array2D<NativeT>& array_2d, int64 parameter_number,
635 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
636 Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
637 if (use_bfloat16_ && literal.shape().element_type() == F32) {
638 literal = LiteralUtil::ConvertF32ToBF16(literal);
639 }
640 std::unique_ptr<GlobalData> data =
641 client_->TransferToServer(literal).ConsumeValueOrDie();
642 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
643 return data;
644 }
645
646 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)647 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
648 const Array3D<NativeT>& array_3d, int64 parameter_number,
649 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
650 Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
651 if (use_bfloat16_ && literal.shape().element_type() == F32) {
652 literal = LiteralUtil::ConvertF32ToBF16(literal);
653 }
654 std::unique_ptr<GlobalData> data =
655 client_->TransferToServer(literal).ConsumeValueOrDie();
656 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
657 return data;
658 }
659
660 template <typename NativeT>
CreateR4Parameter(const Array4D<NativeT> & array_4d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)661 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
662 const Array4D<NativeT>& array_4d, int64 parameter_number,
663 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
664 Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
665 if (use_bfloat16_ && literal.shape().element_type() == F32) {
666 literal = LiteralUtil::ConvertF32ToBF16(literal);
667 }
668 std::unique_ptr<GlobalData> data =
669 client_->TransferToServer(literal).ConsumeValueOrDie();
670 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
671 return data;
672 }
673
674 template <typename NativeT>
CreateParameter(const Array<NativeT> & array,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)675 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
676 const Array<NativeT>& array, int64 parameter_number, const string& name,
677 XlaBuilder* builder, XlaOp* data_handle) {
678 Literal literal = LiteralUtil::CreateFromArray(array);
679 if (use_bfloat16_ && literal.shape().element_type() == F32) {
680 literal = LiteralUtil::ConvertF32ToBF16(literal);
681 }
682 std::unique_ptr<GlobalData> data =
683 client_->TransferToServer(literal).ConsumeValueOrDie();
684 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
685 return data;
686 }
687
688 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32 seed)689 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
690 const int width, NativeT min_value, NativeT max_value, uint32 seed) {
691 std::vector<NativeT> result(width);
692 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
693 for (int i = 0; i < width; ++i) {
694 result[i] = generator.get();
695 }
696 return result;
697 }
698
699 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32 seed)700 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
701 const int rows, const int cols, NativeT min_value, NativeT max_value,
702 uint32 seed) {
703 auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
704 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
705 for (int y = 0; y < rows; ++y) {
706 for (int x = 0; x < cols; ++x) {
707 (*result)(y, x) = generator.get();
708 }
709 }
710 return result;
711 }
712
713 } // namespace xla
714
715 #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
716