1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17
18 #include <limits>
19 #include <vector>
20
21 #include "absl/container/inlined_vector.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30
31 namespace xla {
32 namespace {
33
34 class ComparatorsTest : public ClientLibraryTestBase {
35 public:
ComparatorsTest()36 ComparatorsTest() : builder_(TestName()) {}
builder()37 XlaBuilder* builder() { return &builder_; }
38
39 private:
40 XlaBuilder builder_;
41 };
42
43 template <
44 PrimitiveType type,
45 typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
BuildComparatorAndComparisons(ComparatorsTest * test,bool compare_less_than,absl::InlinedVector<bool,10> * expected)46 void BuildComparatorAndComparisons(ComparatorsTest* test,
47 bool compare_less_than,
48 absl::InlinedVector<bool, 10>* expected) {
49 auto compare = compare_less_than
50 ? CreateScalarLtComputation({type}, test->builder())
51 : CreateScalarGtComputation({type}, test->builder());
52
53 auto negative_nan = ConstantR0<T>(
54 test->builder(), -T(std::numeric_limits<float>::quiet_NaN()));
55 auto positive_nan = ConstantR0<T>(test->builder(),
56 T(std::numeric_limits<float>::quiet_NaN()));
57 auto negative_zero = ConstantR0<T>(test->builder(), T(-0.));
58 auto positive_zero = ConstantR0<T>(test->builder(), T(0.));
59 auto negative_infinity = MinValue(test->builder(), type);
60 auto positive_infinity = MaxValue(test->builder(), type);
61
62 // List the values in the expected sorting order from smallest to largest.
63 std::vector<XlaOp> all_constants{negative_nan, negative_infinity,
64 negative_zero, positive_zero,
65 positive_infinity, positive_nan};
66
67 // Do pairwise comparisons.
68 std::vector<XlaOp> all_comparisons;
69 for (const XlaOp& lhs_constant : all_constants) {
70 for (const XlaOp& rhs_constant : all_constants) {
71 all_comparisons.push_back(Broadcast(
72 Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1}));
73 }
74 }
75
76 // Concantenate the comparison results.
77 ConcatInDim(test->builder(), all_comparisons, 0);
78
79 // If we use less-than comparisons, we expect the comparison to result in true
80 // if the lhs value to be compared appears earlier in 'all_constants' than the
81 // rhs value. Likewise, if we use greater-than comparisons, we expect the
82 // comparison to return true if the rhs value appears earlier in
83 // 'all_constants' than the lhs value.
84 expected->clear();
85 for (int i = 0; i < all_constants.size(); ++i) {
86 for (int j = 0; j < all_constants.size(); ++j) {
87 expected->push_back(compare_less_than ? i < j : i > j);
88 }
89 }
90 }
91
// Runs the BF16 less-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareLtBF16) {
  constexpr bool kUseLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<BF16>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
98
// Runs the BF16 greater-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareGtBF16) {
  constexpr bool kUseLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<BF16>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
105
// Runs the F16 less-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareLtF16) {
  constexpr bool kUseLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F16>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
112
// Runs the F16 greater-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareGtF16) {
  constexpr bool kUseLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F16>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
119
// Runs the F32 less-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareLtF32) {
  constexpr bool kUseLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F32>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
126
// Runs the F32 greater-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareGtF32) {
  constexpr bool kUseLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F32>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
133
// Runs the F64 less-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareLtF64) {
  constexpr bool kUseLessThan = true;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F64>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
140
// Runs the F64 greater-than comparator over every pair of special values and
// checks the results against the expected total order.
XLA_TEST_F(ComparatorsTest, CompareGtF64) {
  constexpr bool kUseLessThan = false;
  absl::InlinedVector<bool, 10> want;
  BuildComparatorAndComparisons<F64>(this, kUseLessThan, &want);
  ComputeAndCompareR1<bool>(builder(), want, /*arguments=*/{});
}
147
148 } // namespace
149 } // namespace xla
150