• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17 
18 #include <limits>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 namespace {
34 
35 using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&,
36                                  absl::Span<const int64>);
37 
// Reinterprets the bits of the floating-point op `value` as a signed integer
// of `bit_width` bits. The bits are remapped so that comparing the integers
// gives a total order over the original floating-point values:
//
//   x = bit_cast<intN>(f);
//   y = x < 0 ? numeric_limits<intN>::max() - x : x;
//
// Under this mapping finite values keep their usual order, -0 sorts before 0,
// and -NaN / NaN land at the very beginning / end of the ordering.
// Evaluating max - x directly on signed values could overflow. It is therefore
// computed on the unsigned view of the bits and bitcast back to signed.
//
// `bit_width` must be 16, 32 or 64. Any other width is reported as an
// InvalidArgument error on value's builder, and the op returned by
// ReportError is returned. NOTE(review): callers are expected to pass the
// element bit width of `value` so the bitcasts are width-preserving — confirm
// at new call sites.
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
                                            int64 bit_width) {
  XlaBuilder* const builder = value.builder();

  XlaOp max_value;
  PrimitiveType signed_type;
  PrimitiveType unsigned_type;
  if (bit_width == 16) {
    max_value = ConstantR0(
        builder, static_cast<uint16>(std::numeric_limits<int16>::max()));
    signed_type = S16;
    unsigned_type = U16;
  } else if (bit_width == 32) {
    max_value = ConstantR0(
        builder, static_cast<uint32>(std::numeric_limits<int32>::max()));
    signed_type = S32;
    unsigned_type = U32;
  } else if (bit_width == 64) {
    max_value = ConstantR0(
        builder, static_cast<uint64>(std::numeric_limits<int64>::max()));
    signed_type = S64;
    unsigned_type = U64;
  } else {
    return builder->ReportError(InvalidArgument(
        "Invalid bit width %lld for Comparator floating point parameter.",
        bit_width));
  }

  // Two views of the same bits: the signed view drives the sign test and is
  // the result for non-negative inputs; the unsigned view lets max - x wrap
  // instead of overflowing.
  const XlaOp as_signed = BitcastConvertType(value, signed_type);
  const XlaOp as_unsigned = BitcastConvertType(value, unsigned_type);
  const XlaOp mirrored =
      BitcastConvertType(Sub(max_value, as_unsigned), signed_type);
  const XlaOp sign_bit_set = Lt(as_signed, Zero(builder, signed_type));
  return Select(sign_bit_set, mirrored, as_signed);
}
91 
CreateScalarComparisonComputation(const string & name,const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder,XlaOpGenerator generator)92 XlaComputation CreateScalarComparisonComputation(
93     const string& name, const std::vector<PrimitiveType>& operand_types,
94     XlaBuilder* builder, XlaOpGenerator generator) {
95   // Create a default computation where we compare only the first two
96   // parameters of type 'operand_types[0]'.
97   auto b = builder->CreateSubBuilder(name);
98   if (operand_types.empty()) {
99     b->ReportError(InvalidArgument("operand_types should not be empty"));
100     return b->BuildAndNoteError();
101   }
102 
103   int64 parameter_count = 0;
104   XlaOp first_lhs_param;
105   XlaOp first_rhs_param;
106 
107   // For each type in 'operand_types' we create two parameters of this type. The
108   // idea is that this computation can be used by n-ary Sort, and potentially
109   // should support comparing also the other operands of sort. In this default
110   // computation, however, we will not actually use any parameters except the
111   // first two.
112   for (auto operand_type : operand_types) {
113     auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
114     auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape,
115                                absl::StrCat("p.", parameter_count, ".lhs"));
116     auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
117                                absl::StrCat("p.", parameter_count, ".rhs"));
118     if (parameter_count == 0) {
119       first_lhs_param = lhs_param;
120       first_rhs_param = rhs_param;
121     }
122     ++parameter_count;
123   }
124   if (primitive_util::IsFloatingPointType(operand_types[0])) {
125     PrimitiveType compare_type = operand_types[0];
126     // Special-case handling for BF16. We currently do not support direct
127     // comparisons with BF16, so we convert to F32 and then use the F32
128     // comparison logic.
129     if (compare_type == BF16) {
130       compare_type = F32;
131       first_lhs_param = ConvertElementType(first_lhs_param, F32);
132       first_rhs_param = ConvertElementType(first_rhs_param, F32);
133     }
134     int64 bit_width = primitive_util::BitWidth(compare_type);
135     first_lhs_param =
136         BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width);
137     first_rhs_param =
138         BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width);
139   }
140   generator(first_lhs_param, first_rhs_param, {});
141   return b->BuildAndNoteError();
142 }
143 }  // namespace
144 
145 // Creates a scalar less-than computation and returns it.
CreateScalarLtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)146 XlaComputation CreateScalarLtComputation(
147     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
148   return CreateScalarComparisonComputation("compare-less-than", operand_types,
149                                            builder, Lt);
150 }
151 
152 // Creates a scalar greater-than computation and returns it.
CreateScalarGtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)153 XlaComputation CreateScalarGtComputation(
154     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
155   return CreateScalarComparisonComputation("compare-greater-than",
156                                            operand_types, builder, Gt);
157 }
158 
159 }  // namespace xla
160