• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 
31 namespace xla {
32 namespace {
33 
34 class ComparatorsTest : public ClientLibraryTestBase {
35  public:
ComparatorsTest()36   ComparatorsTest() : builder_(TestName()) {}
builder()37   XlaBuilder* builder() { return &builder_; }
38 
39  private:
40   XlaBuilder builder_;
41 };
42 
43 template <
44     PrimitiveType type,
45     typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
BuildComparatorAndComparisons(ComparatorsTest * test,bool compare_less_than,absl::InlinedVector<bool,10> * expected)46 void BuildComparatorAndComparisons(ComparatorsTest* test,
47                                    bool compare_less_than,
48                                    absl::InlinedVector<bool, 10>* expected) {
49   auto compare = compare_less_than
50                      ? CreateScalarLtComputation({type}, test->builder())
51                      : CreateScalarGtComputation({type}, test->builder());
52 
53   auto negative_nan = ConstantR0<T>(
54       test->builder(), -T(std::numeric_limits<float>::quiet_NaN()));
55   auto positive_nan = ConstantR0<T>(test->builder(),
56                                     T(std::numeric_limits<float>::quiet_NaN()));
57   auto negative_zero = ConstantR0<T>(test->builder(), T(-0.));
58   auto positive_zero = ConstantR0<T>(test->builder(), T(0.));
59   auto negative_infinity = MinValue(test->builder(), type);
60   auto positive_infinity = MaxValue(test->builder(), type);
61 
62   // List the values in the expected sorting order from smallest to largest.
63   std::vector<XlaOp> all_constants{negative_nan,      negative_infinity,
64                                    negative_zero,     positive_zero,
65                                    positive_infinity, positive_nan};
66 
67   // Do pairwise comparisons.
68   std::vector<XlaOp> all_comparisons;
69   for (const XlaOp& lhs_constant : all_constants) {
70     for (const XlaOp& rhs_constant : all_constants) {
71       all_comparisons.push_back(Broadcast(
72           Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1}));
73     }
74   }
75 
76   // Concantenate the comparison results.
77   ConcatInDim(test->builder(), all_comparisons, 0);
78 
79   // If we use less-than comparisons, we expect the comparison to result in true
80   // if the lhs value to be compared appears earlier in 'all_constants' than the
81   // rhs value. Likewise, if we use greater-than comparisons, we expect the
82   // comparison to return true if the rhs value appears earlier in
83   // 'all_constants' than the lhs value.
84   expected->clear();
85   for (int i = 0; i < all_constants.size(); ++i) {
86     for (int j = 0; j < all_constants.size(); ++j) {
87       expected->push_back(compare_less_than ? i < j : i > j);
88     }
89   }
90 }
91 
XLA_TEST_F(ComparatorsTest, CompareLtBF16) {
  // bf16: `lhs < rhs` must be true exactly when lhs precedes rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan,
                                      &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
98 
XLA_TEST_F(ComparatorsTest, CompareGtBF16) {
  // bf16: `lhs > rhs` must be true exactly when lhs follows rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<BF16>(this, kCompareLessThan,
                                      &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
105 
XLA_TEST_F(ComparatorsTest, CompareLtF16) {
  // f16: `lhs < rhs` must be true exactly when lhs precedes rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
112 
XLA_TEST_F(ComparatorsTest, CompareGtF16) {
  // f16: `lhs > rhs` must be true exactly when lhs follows rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F16>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
119 
XLA_TEST_F(ComparatorsTest, CompareLtF32) {
  // f32: `lhs < rhs` must be true exactly when lhs precedes rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
126 
XLA_TEST_F(ComparatorsTest, CompareGtF32) {
  // f32: `lhs > rhs` must be true exactly when lhs follows rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F32>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
133 
XLA_TEST_F(ComparatorsTest, CompareLtF64) {
  // f64: `lhs < rhs` must be true exactly when lhs precedes rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = true;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
140 
XLA_TEST_F(ComparatorsTest, CompareGtF64) {
  // f64: `lhs > rhs` must be true exactly when lhs follows rhs in the
  // -NaN, -inf, -0, +0, +inf, +NaN ordering.
  constexpr bool kCompareLessThan = false;
  absl::InlinedVector<bool, 10> expected_results;
  BuildComparatorAndComparisons<F64>(this, kCompareLessThan,
                                     &expected_results);
  ComputeAndCompareR1<bool>(builder(), expected_results, /*arguments=*/{});
}
147 
148 }  // namespace
149 }  // namespace xla
150