1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17
18 #include <limits>
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace xla {
33 namespace {
34
// Function-pointer type for the comparison op that finishes a generated
// comparator computation. In this file it is invoked with the two (possibly
// converted) first parameters and an empty span as the third argument; the
// public entry points pass Lt and Gt.
using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&,
                                 absl::Span<const int64>);
37
// Reinterprets the floating point 'value' of width 'bit_width' (16, 32 or 64
// bits) as a signed integer of the same width, remapped so that comparing the
// integers gives the ordering described below. Any other 'bit_width' reports
// an InvalidArgument error on the builder of 'value' and returns the op
// produced by that report.
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
                                            int64 bit_width) {
  XlaBuilder* const builder = value.builder();
  PrimitiveType signed_type;
  PrimitiveType unsigned_type;
  // Largest signed integer of the requested width, held as an unsigned
  // constant so that the subtraction below is done in unsigned arithmetic.
  XlaOp signed_max;
  if (bit_width == 16) {
    signed_max = ConstantR0(
        builder, static_cast<uint16>(std::numeric_limits<int16>::max()));
    signed_type = S16;
    unsigned_type = U16;
  } else if (bit_width == 32) {
    signed_max = ConstantR0(
        builder, static_cast<uint32>(std::numeric_limits<int32>::max()));
    signed_type = S32;
    unsigned_type = U32;
  } else if (bit_width == 64) {
    signed_max = ConstantR0(
        builder, static_cast<uint64>(std::numeric_limits<int64>::max()));
    signed_type = S64;
    unsigned_type = U64;
  } else {
    return builder->ReportError(
        InvalidArgument("Invalid bit width %lld for Comparator floating "
                        "point parameter.",
                        bit_width));
  }

  // Map the floating point bit pattern to an integer that sorts the same way.
  // For a float f (shown here for 32 bits):
  //   x = bit_cast<int32>(f);
  //   y = x < 0 ? numeric_limits<int32>::max() - x : x;
  // As an int32, y orders finite values in the obvious way, places -0 before
  // 0, and puts -NaN at the very beginning and NaN at the very end.
  // Negating x could overflow, so numeric_limits<int32>::max() - x is
  // evaluated on the unsigned reinterpretation and the result is then
  // reinterpreted as signed.
  const XlaOp as_signed = BitcastConvertType(value, signed_type);
  const XlaOp as_unsigned = BitcastConvertType(value, unsigned_type);
  const XlaOp mirrored =
      BitcastConvertType(Sub(signed_max, as_unsigned), signed_type);
  const XlaOp sign_bit_set = Lt(as_signed, Zero(builder, signed_type));
  return Select(sign_bit_set, mirrored, as_signed);
}
91
CreateScalarComparisonComputation(const string & name,const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder,XlaOpGenerator generator)92 XlaComputation CreateScalarComparisonComputation(
93 const string& name, const std::vector<PrimitiveType>& operand_types,
94 XlaBuilder* builder, XlaOpGenerator generator) {
95 // Create a default computation where we compare only the first two
96 // parameters of type 'operand_types[0]'.
97 auto b = builder->CreateSubBuilder(name);
98 if (operand_types.empty()) {
99 b->ReportError(InvalidArgument("operand_types should not be empty"));
100 return b->BuildAndNoteError();
101 }
102
103 int64 parameter_count = 0;
104 XlaOp first_lhs_param;
105 XlaOp first_rhs_param;
106
107 // For each type in 'operand_types' we create two parameters of this type. The
108 // idea is that this computation can be used by n-ary Sort, and potentially
109 // should support comparing also the other operands of sort. In this default
110 // computation, however, we will not actually use any parameters except the
111 // first two.
112 for (auto operand_type : operand_types) {
113 auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
114 auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape,
115 absl::StrCat("p.", parameter_count, ".lhs"));
116 auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
117 absl::StrCat("p.", parameter_count, ".rhs"));
118 if (parameter_count == 0) {
119 first_lhs_param = lhs_param;
120 first_rhs_param = rhs_param;
121 }
122 ++parameter_count;
123 }
124 if (primitive_util::IsFloatingPointType(operand_types[0])) {
125 PrimitiveType compare_type = operand_types[0];
126 // Special-case handling for BF16. We currently do not support direct
127 // comparisons with BF16, so we convert to F32 and then use the F32
128 // comparison logic.
129 if (compare_type == BF16) {
130 compare_type = F32;
131 first_lhs_param = ConvertElementType(first_lhs_param, F32);
132 first_rhs_param = ConvertElementType(first_rhs_param, F32);
133 }
134 int64 bit_width = primitive_util::BitWidth(compare_type);
135 first_lhs_param =
136 BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width);
137 first_rhs_param =
138 BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width);
139 }
140 generator(first_lhs_param, first_rhs_param, {});
141 return b->BuildAndNoteError();
142 }
143 } // namespace
144
145 // Creates a scalar less-than computation and returns it.
CreateScalarLtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)146 XlaComputation CreateScalarLtComputation(
147 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
148 return CreateScalarComparisonComputation("compare-less-than", operand_types,
149 builder, Lt);
150 }
151
152 // Creates a scalar greater-than computation and returns it.
CreateScalarGtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)153 XlaComputation CreateScalarGtComputation(
154 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
155 return CreateScalarComparisonComputation("compare-greater-than",
156 operand_types, builder, Gt);
157 }
158
159 } // namespace xla
160