1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46
47 // Sets the use_bfloat16 on a container of test cases according to the values in
48 // use_bfloat16_params. Generates one set of test cases for each values in
49 // use_bfloat16_params with that value. Returns the result.
50 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)51 std::vector<TestCase> ExpandUseBfloat16(
52 absl::Span<const bool> use_bfloat16_params,
53 absl::Span<const TestCase> specs) {
54 std::vector<TestCase> expanded;
55 for (bool use_bfloat16 : use_bfloat16_params) {
56 for (const auto& spec : specs) {
57 expanded.push_back(spec);
58 expanded.back().use_bfloat16 = use_bfloat16;
59 }
60 }
61 return expanded;
62 }
63
64 // A client library test establishes an in-process XLA client connection.
65 class ClientLibraryTestBase : public ::testing::Test {
66 protected:
67 explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
68
69 // Creates a new ClientLibraryTestBase with custom client options.
70 ClientLibraryTestBase(se::Platform* platform,
71 const LocalClientOptions& client_options);
72
73 // Returns the name of the test currently being run.
74 string TestName() const;
75
SetFastMathDisabled(bool disabled)76 void SetFastMathDisabled(bool disabled) {
77 auto* opts = execution_options_.mutable_debug_options();
78 opts->set_xla_cpu_enable_fast_math(!disabled);
79 opts->set_xla_gpu_enable_fast_min_max(!disabled);
80 }
81
SetSeed(uint64 seed)82 void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
83
84 // Provides mutable access to the execution DebugOptions field; this lets
85 // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()86 DebugOptions* mutable_debug_options() {
87 return execution_options_.mutable_debug_options();
88 }
89
90 // TODO(b/25566808): Add helper that populates a literal from a testdata file.
91
92 // Convenience methods for building and running a computation with the member
93 // execution options. Modify execution_options_ in your test if you want to
94 // customize the options.
95 StatusOr<std::unique_ptr<GlobalData>> Execute(
96 XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
97
98 StatusOr<Literal> ExecuteAndTransfer(
99 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
100 const Shape* shape_with_output_layout = nullptr);
101
102 StatusOr<Literal> ExecuteAndTransfer(
103 const XlaComputation& computation,
104 absl::Span<GlobalData* const> arguments,
105 const Shape* shape_with_output_layout = nullptr);
106
107 // This executes the computation via the reference client (which connects a
108 // interpreter backend). The result is used as the expected values of the
109 // computation.
110 StatusOr<Literal> ExecuteAndTransferReference(
111 const XlaComputation& computation,
112 absl::Span<GlobalData* const> arguments,
113 const Shape* shape_with_output_layout = nullptr);
114
115 // Run a computation and return its value as a string. If an error
116 // occurs, then instead return the error as a string.
117 string ExecuteToString(XlaBuilder* builder,
118 absl::Span<GlobalData* const> arguments);
119
120 // Convenience methods for building and running a computation, transferring
121 // the result, and comparing it to the expected value(s). Methods are
122 // templated on the native host type which maps to specific XLA types (See
123 // XlaBuilder for details). For each rank, two forms are
124 // provided: one for floating point types with an ErrorSpec parameter, and one
125 // for integral types without the ErrorSpec parameter.
126 template <typename NativeT>
127 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
128 absl::Span<GlobalData* const> arguments);
129 template <typename NativeT>
130 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
131 absl::Span<GlobalData* const> arguments,
132 ErrorSpec error);
133
134 template <typename NativeT>
135 void ComputeAndCompareR1(XlaBuilder* builder,
136 absl::Span<const NativeT> expected,
137 absl::Span<GlobalData* const> arguments);
138 template <typename NativeT>
139 void ComputeAndCompareR1(XlaBuilder* builder,
140 absl::Span<const NativeT> expected,
141 absl::Span<GlobalData* const> arguments,
142 ErrorSpec error);
143
144 // As above, but uses a bitmap to hold the predicate vector to avoid
145 // deficiencies of vector<bool>.
146 void ComputeAndCompareR1(XlaBuilder* builder,
147 const tensorflow::core::Bitmap& expected,
148 absl::Span<GlobalData* const> arguments);
149
150 template <typename NativeT>
151 void ComputeAndCompareR2(XlaBuilder* builder,
152 const Array2D<NativeT>& expected,
153 absl::Span<GlobalData* const> arguments);
154 template <typename NativeT>
155 void ComputeAndCompareR2(XlaBuilder* builder,
156 const Array2D<NativeT>& expected,
157 absl::Span<GlobalData* const> arguments,
158 ErrorSpec error);
159
160 template <typename NativeT>
161 void ComputeAndCompareR3(XlaBuilder* builder,
162 const Array3D<NativeT>& expected,
163 absl::Span<GlobalData* const> arguments);
164 template <typename NativeT>
165 void ComputeAndCompareR3(XlaBuilder* builder,
166 const Array3D<NativeT>& expected,
167 absl::Span<GlobalData* const> arguments,
168 ErrorSpec error);
169
170 template <typename NativeT>
171 void ComputeAndCompareR4(XlaBuilder* builder,
172 const Array4D<NativeT>& expected,
173 absl::Span<GlobalData* const> arguments);
174 template <typename NativeT>
175 void ComputeAndCompareR4(XlaBuilder* builder,
176 const Array4D<NativeT>& expected,
177 absl::Span<GlobalData* const> arguments,
178 ErrorSpec error);
179
180 // Build and run the computation and compare the result with the given
181 // literal. shape_with_layout indicates the result layout to request when
182 // calling Execute.
183 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
184 absl::Span<GlobalData* const> arguments,
185 const Shape* shape_with_layout = nullptr);
186 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
187 absl::Span<GlobalData* const> arguments,
188 ErrorSpec error,
189 const Shape* shape_with_layout = nullptr);
190
191 // Build and run the computation and return the result as a literal.
192 // shape_with_layout indicates the result layout to request when calling
193 // Execute.
194 StatusOr<Literal> ComputeAndTransfer(
195 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
196 const Shape* shape_with_layout = nullptr);
197
198 // ComputeAndCompare variant which returns an error status.
199 Status ComputeAndCompareLiteralWithStatus(
200 XlaBuilder* builder, const Literal& expected,
201 absl::Span<GlobalData* const> arguments,
202 const Shape* shape_with_layout = nullptr);
203 Status ComputeAndCompareLiteralWithStatus(
204 XlaBuilder* builder, const Literal& expected,
205 absl::Span<GlobalData* const> arguments, ErrorSpec error,
206 const Shape* shape_with_layout = nullptr);
207
208 // Compare the result of the computation to a strings. In XLA strings are
209 // represented using rank-1 U8 shapes.
210 void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
211 absl::Span<GlobalData* const> arguments);
212
213 // Convenience method for running a built computation, transferring the
214 // result, and comparing it to the expected tuple literal.
215 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
216 absl::Span<GlobalData* const> arguments);
217 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218 absl::Span<GlobalData* const> arguments,
219 ErrorSpec error);
220
221 // Convenience method for running a built computation and comparing the result
222 // with the reference result.
223 void ComputeAndCompare(XlaBuilder* builder,
224 absl::Span<const Literal> arguments);
225 void ComputeAndCompare(XlaBuilder* builder,
226 absl::Span<const Literal> arguments, ErrorSpec error);
227
228 // Create scalar operations for use in reductions.
229 XlaComputation CreateScalarRelu();
230 XlaComputation CreateScalarMax();
231 XlaComputation CreateScalarReluSensitivity();
232
233 // Special case convenience functions for creating filled arrays.
234
235 // Creates an array of pseudorandom values lying between the given minimum and
236 // maximum values.
237 template <typename NativeT>
238 std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
239 NativeT max_value, uint32 seed);
240 template <typename NativeT>
241 std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
242 const int cols,
243 NativeT min_value,
244 NativeT max_value,
245 uint32 seed);
246
247 // Creates a (rows x cols) array filled in the following form:
248 //
249 // [ 0 1 ... cols-1]
250 // [ 1,000 1,001 ... 1000.0 + cols-1]
251 // [ ... ... ... ...]
252 // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1]
253 //
254 // If provided, offset is added uniformly to every element (e.g. an offset of
255 // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
256 std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
257 const int cols,
258 float offset = 0.0);
259
260 // Creates a (rows x cols) array as above, padded out to
261 // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
262 // and cols_padded > cols.
263 std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
264 const int rows, const int cols, const int rows_padded,
265 const int cols_padded);
266
267 // Creates a parameter instruction, transfers the literal for the parameter to
268 // server, then stores into "data_handle" the global handle for that
269 // parameter. When the use_bfloat16 flag is set but the literal has F32
270 // elements, the literal will be converted to BF16 before being transferred.
271 std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
272 int64 parameter_number, const Literal& literal, const string& name,
273 XlaBuilder* builder, XlaOp* data_handle);
274
275 // As above, but the caller can specify the device that the literal is
276 // transferred to. If device_handle is nullptr, the literal will be
277 // transferred to the default device.
278 std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
279 int64 parameter_number, const Literal& literal, const string& name,
280 const DeviceHandle* device_handle, XlaBuilder* builder,
281 XlaOp* data_handle);
282
283 // Creates a parameter instruction and sets the value that will be passed to
284 // the computation as specified. This function must be used for all parameters
285 // or none and no parameters must be passed when invoking the computation if
286 // using this mechanism. If using this mechanism, then each parameter must be
287 // set exactly once. The first added parameter gets index 0, then 1 and so on.
288 XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
289
290 template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)291 XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
292 return AddParam(LiteralUtil::CreateFromArray(argument), builder);
293 }
294
295 // Creates a constant instruction with the given literal. When the
296 // use_bfloat16 flag is set but the literal has F32 elements, the elements
297 // will be converted to BF16s.
298 XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
299
300 // Creates a constant instruction with the given array. When the use_bfloat16
301 // flag is set but the array has float elements, the elements will be
302 // converted to bfloat16s.
303
304 template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)305 XlaOp CreateConstantFromArray(const Array<NativeT>& array,
306 XlaBuilder* builder) {
307 return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
308 builder);
309 }
310
311 // Same as CreateConstantFromArray, but for scalars.
312 template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)313 XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
314 return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
315 builder);
316 }
317
318 // Creates a parameter instruction that wraps a given value and then stores
319 // into "data_handle" the global handle for that parameter.
320 //
321 // "parameter_number" is the parameter number.
322 // "name" is the name of the parameter instruction.
323 //
324 // When the use_bfloat16 flag is set but NativeT is float, the data will be
325 // converted to bfloat16.
326 template <typename NativeT>
327 std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
328 int64 parameter_number,
329 const string& name,
330 XlaBuilder* builder,
331 XlaOp* data_handle);
332
333 // Creates a parameter instruction that wraps the given values and then stores
334 // into "data_handle" the global handle for that parameter.
335 //
336 // "parameter_number" is the parameter number.
337 // "name" is the name of the parameter instruction.
338 //
339 // When the use_bfloat16 flag is set but NativeT is float, the data will be
340 // converted to bfloat16.
341 template <typename NativeT>
342 std::unique_ptr<GlobalData> CreateR1Parameter(
343 absl::Span<const NativeT> values, int64 parameter_number,
344 const string& name, XlaBuilder* builder, XlaOp* data_handle);
345
346 // Creates a parameter instruction that wraps the given constant array
347 // "array_2d" and then stores to "data_handle" the global handle for that
348 // parameter.
349 //
350 // "parameter_number" is the parameter number.
351 // "name" is the name of the parameter instruction.
352 //
353 // When the use_bfloat16 flag is set but NativeT is float, the data will be
354 // converted to bfloat16.
355 template <typename NativeT>
356 std::unique_ptr<GlobalData> CreateR2Parameter(
357 const Array2D<NativeT>& array_2d, int64 parameter_number,
358 const string& name, XlaBuilder* builder, XlaOp* data_handle);
359
360 // Creates a parameter instruction that wraps the given constant array
361 // "array_3d" and then stores to "data_handle" the global handle for that
362 // parameter.
363 //
364 // "parameter_number" is the parameter number.
365 // "name" is the name of the parameter instruction.
366 //
367 // When the use_bfloat16 flag is set but NativeT is float, the data will be
368 // converted to bfloat16.
369 template <typename NativeT>
370 std::unique_ptr<GlobalData> CreateR3Parameter(
371 const Array3D<NativeT>& array_3d, int64 parameter_number,
372 const string& name, XlaBuilder* builder, XlaOp* data_handle);
373
374 // Getter and setter for the use_bfloat16 flag, which indicates whether to run
375 // tests with all float-type input/output converted to bfloat16.
use_bfloat16()376 bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)377 void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
378
379 // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()380 PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
381
382 // Executes the computation and calculates the expected reference value using
383 // the reference client. Returns two literals in the order of (expected,
384 // actual).
385 StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
386 XlaBuilder* builder, absl::Span<const Literal> arguments);
387
388 LocalClient* client_;
389 LocalClient* ref_client_; // To compute reference result.
390 ExecutionOptions execution_options_;
391
392 private:
393 Status ComputeAndCompareLiteralWithAllOutputLayouts(
394 const xla::XlaComputation& computation, const Literal& expected,
395 absl::Span<GlobalData* const> arguments,
396 const std::function<void(const Literal& actual,
397 const string& error_message)>& verify_output);
398 Status ComputeAndCompareLiteralWithAllInputLayouts(
399 const xla::XlaComputation& computation, const Literal& expected,
400 absl::Span<GlobalData* const> arguments,
401 const std::function<void(const Literal& actual,
402 const string& error_message)>& verify_output,
403 const Shape* output_with_layout = nullptr);
404
405 // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true.
406 Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
407 Shape MaybeConvertShapeToBfloat16(const Shape& shape);
408
409 // Whether to run tests with all float-type input/output converted to
410 // bfloat16.
411 bool use_bfloat16_ = false;
412
413 // Arguments to be passed to the computation when it runs.
414 std::vector<Literal> arguments_;
415 };
416
417 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)418 void ClientLibraryTestBase::ComputeAndCompareR0(
419 XlaBuilder* builder, NativeT expected,
420 absl::Span<GlobalData* const> arguments) {
421 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
422 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
423 arguments);
424 }
425
426 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)427 void ClientLibraryTestBase::ComputeAndCompareR0(
428 XlaBuilder* builder, NativeT expected,
429 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
430 static_assert(std::is_same<NativeT, float>::value ||
431 std::is_same<NativeT, double>::value ||
432 std::is_same<NativeT, bfloat16>::value ||
433 std::is_same<NativeT, half>::value ||
434 std::is_same<NativeT, complex64>::value ||
435 std::is_same<NativeT, complex128>::value,
436 "Float or complex type required when specifying an ErrorSpec");
437 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
438 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
439 arguments, error);
440 }
441
442 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)443 void ClientLibraryTestBase::ComputeAndCompareR1(
444 XlaBuilder* builder, absl::Span<const NativeT> expected,
445 absl::Span<GlobalData* const> arguments) {
446 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
447 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
448 arguments);
449 }
450
451 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)452 void ClientLibraryTestBase::ComputeAndCompareR1(
453 XlaBuilder* builder, absl::Span<const NativeT> expected,
454 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
455 static_assert(std::is_same<NativeT, float>::value ||
456 std::is_same<NativeT, double>::value ||
457 std::is_same<NativeT, bfloat16>::value ||
458 std::is_same<NativeT, half>::value ||
459 std::is_same<NativeT, complex64>::value ||
460 std::is_same<NativeT, complex128>::value,
461 "Float or complex type required when specifying an ErrorSpec");
462 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
463 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
464 arguments, error);
465 }
466
467 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)468 void ClientLibraryTestBase::ComputeAndCompareR2(
469 XlaBuilder* builder, const Array2D<NativeT>& expected,
470 absl::Span<GlobalData* const> arguments) {
471 Literal expected_literal =
472 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
473 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
474 arguments);
475 }
476
477 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)478 void ClientLibraryTestBase::ComputeAndCompareR2(
479 XlaBuilder* builder, const Array2D<NativeT>& expected,
480 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
481 static_assert(std::is_same<NativeT, float>::value ||
482 std::is_same<NativeT, double>::value ||
483 std::is_same<NativeT, bfloat16>::value ||
484 std::is_same<NativeT, half>::value ||
485 std::is_same<NativeT, complex64>::value ||
486 std::is_same<NativeT, complex128>::value,
487 "Float or complex type required when specifying an ErrorSpec");
488 Literal expected_literal =
489 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
490 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
491 arguments, error);
492 }
493
494 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)495 void ClientLibraryTestBase::ComputeAndCompareR3(
496 XlaBuilder* builder, const Array3D<NativeT>& expected,
497 absl::Span<GlobalData* const> arguments) {
498 Literal expected_literal =
499 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
500 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
501 arguments);
502 }
503
504 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)505 void ClientLibraryTestBase::ComputeAndCompareR3(
506 XlaBuilder* builder, const Array3D<NativeT>& expected,
507 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
508 static_assert(std::is_same<NativeT, float>::value ||
509 std::is_same<NativeT, double>::value ||
510 std::is_same<NativeT, bfloat16>::value ||
511 std::is_same<NativeT, half>::value ||
512 std::is_same<NativeT, complex64>::value ||
513 std::is_same<NativeT, complex128>::value,
514 "Float or complex type required when specifying an ErrorSpec");
515 Literal expected_literal =
516 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
517 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
518 arguments, error);
519 }
520
521 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)522 void ClientLibraryTestBase::ComputeAndCompareR4(
523 XlaBuilder* builder, const Array4D<NativeT>& expected,
524 absl::Span<GlobalData* const> arguments) {
525 Literal expected_literal =
526 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
527 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
528 arguments);
529 }
530
531 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)532 void ClientLibraryTestBase::ComputeAndCompareR4(
533 XlaBuilder* builder, const Array4D<NativeT>& expected,
534 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
535 static_assert(std::is_same<NativeT, float>::value ||
536 std::is_same<NativeT, double>::value ||
537 std::is_same<NativeT, bfloat16>::value ||
538 std::is_same<NativeT, half>::value ||
539 std::is_same<NativeT, complex64>::value ||
540 std::is_same<NativeT, complex128>::value,
541 "Float or complex type required when specifying an ErrorSpec");
542 Literal expected_literal =
543 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
544 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
545 arguments, error);
546 }
547
548 template <typename NativeT>
CreateR0Parameter(NativeT value,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)549 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
550 NativeT value, int64 parameter_number, const string& name,
551 XlaBuilder* builder, XlaOp* data_handle) {
552 Literal literal = LiteralUtil::CreateR0(value);
553 if (use_bfloat16_ && literal.shape().element_type() == F32) {
554 literal = LiteralUtil::ConvertF32ToBF16(literal);
555 }
556 std::unique_ptr<GlobalData> data =
557 client_->TransferToServer(literal).ConsumeValueOrDie();
558 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
559 return data;
560 }
561
562 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)563 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
564 absl::Span<const NativeT> values, int64 parameter_number,
565 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
566 Literal literal = LiteralUtil::CreateR1(values);
567 if (use_bfloat16_ && literal.shape().element_type() == F32) {
568 literal = LiteralUtil::ConvertF32ToBF16(literal);
569 }
570 std::unique_ptr<GlobalData> data =
571 client_->TransferToServer(literal).ConsumeValueOrDie();
572 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
573 return data;
574 }
575
576 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)577 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
578 const Array2D<NativeT>& array_2d, int64 parameter_number,
579 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
580 Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
581 if (use_bfloat16_ && literal.shape().element_type() == F32) {
582 literal = LiteralUtil::ConvertF32ToBF16(literal);
583 }
584 std::unique_ptr<GlobalData> data =
585 client_->TransferToServer(literal).ConsumeValueOrDie();
586 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
587 return data;
588 }
589
590 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64 parameter_number,const string & name,XlaBuilder * builder,XlaOp * data_handle)591 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
592 const Array3D<NativeT>& array_3d, int64 parameter_number,
593 const string& name, XlaBuilder* builder, XlaOp* data_handle) {
594 Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
595 if (use_bfloat16_ && literal.shape().element_type() == F32) {
596 literal = LiteralUtil::ConvertF32ToBF16(literal);
597 }
598 std::unique_ptr<GlobalData> data =
599 client_->TransferToServer(literal).ConsumeValueOrDie();
600 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
601 return data;
602 }
603
604 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32 seed)605 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
606 const int width, NativeT min_value, NativeT max_value, uint32 seed) {
607 std::vector<NativeT> result(width);
608 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
609 for (int i = 0; i < width; ++i) {
610 result[i] = generator.get();
611 }
612 return result;
613 }
614
615 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32 seed)616 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
617 const int rows, const int cols, NativeT min_value, NativeT max_value,
618 uint32 seed) {
619 auto result = absl::make_unique<Array2D<NativeT>>(rows, cols);
620 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
621 for (int y = 0; y < rows; ++y) {
622 for (int x = 0; x < cols; ++x) {
623 (*result)(y, x) = generator.get();
624 }
625 }
626 return result;
627 }
628
629 } // namespace xla
630
631 #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
632