1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
18
#include <functional>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/bfloat16.h"
31
32 namespace xla {
33
34 // A utility class for primitive comparisons. A comparison includes three
35 // components: the type of the elements being compared (F32, S16, etc), whether
36 // it is a partial or total order comparison, and the actual comparison operator
37 // (==, <=, >, etc).
38 //
39 // Note that integer comparisons are always total order. Float comparisons can
40 // be either total or partial order.
41 //
42 // Some examples:
43 //
44 // Comparison a(
45 // Comparison::Direction::kLt,
46 // xla::PrimitiveType::BF16,
47 // Comparison::Order::kTotal
48 // );
49 // a.ToString(); /* ".LT.BF16.TOTALORDER" */
50 //
51 // Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32);
52 // b.IsTotalOrder(); /* true */
53 class Comparison {
54 public:
55 // Represents the ordering of the comparison.
56 enum class Order : uint8_t {
57 // https://en.wikipedia.org/wiki/Total_order
58 kTotal,
59 // https://en.wikipedia.org/wiki/Partially_ordered_set
60 kPartial,
61 };
62
63 // Represents different comparison operations.
64 enum class Direction : uint8_t {
65 kEq,
66 kNe,
67 kGe,
68 kGt,
69 kLe,
70 kLt,
71 };
72
73 // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType
74 // and Comparison::Order, since there are multiple floating point
75 // representations that support total ordering.
76 enum class [[deprecated("Use PrimitiveType and Order")]] Type : uint8_t{
77 kFloat,
78 kFloatTotalOrder,
79 kSigned,
80 kUnsigned,
81 };
82
83 Comparison() = delete;
84
85 // This will default to the expected behavior for Comparison::Order: integers
86 // will use total ordering, and floats will use partial ordering.
87 explicit Comparison(Direction dir, PrimitiveType type);
88
89 // Pass in a Comparison::Order to specify a non-default ordering, e.g., some
90 // targets may support total order floating point type comparisons.
91 explicit Comparison(Direction dir, PrimitiveType type, Order order);
92
93 // Returns a comparison with a primitive type matching the Comparison::Type
94 // and using a default bit width of 32. For example,
95 // Comparison(Direction::kLt, Type::kFloat).PrimitiveType() /* F32 */
96 [[deprecated(
97 "Use Comparison(Comparison::Direction, "
98 "PrimitiveType)")]] explicit Comparison(Direction dir, Type type);
99
GetDirection()100 inline Direction GetDirection() const { return dir_; }
GetPrimitiveType()101 inline PrimitiveType GetPrimitiveType() const { return primitive_type_; }
GetOrder()102 inline Order GetOrder() const { return order_; }
103
GetType()104 [[deprecated("Use GetPrimitiveType() and GetOrder()")]] inline Type GetType()
105 const {
106 return type_;
107 }
108
IsEq()109 inline bool IsEq() const { return dir_ == Direction::kEq; }
IsNe()110 inline bool IsNe() const { return dir_ == Direction::kNe; }
IsGe()111 inline bool IsGe() const { return dir_ == Direction::kGe; }
IsGt()112 inline bool IsGt() const { return dir_ == Direction::kGt; }
IsLt()113 inline bool IsLt() const { return dir_ == Direction::kLt; }
IsTotalOrder()114 inline bool IsTotalOrder() const { return order_ == Order::kTotal; }
IsPartialOrder()115 inline bool IsPartialOrder() const { return order_ == Order::kPartial; }
116
117 // Returns whether this is a floating point total order comparison.
IsF32TotalOrder()118 inline bool IsF32TotalOrder() const {
119 return primitive_type_ == PrimitiveType::F32 && IsTotalOrder();
120 }
IsBf16TotalOrder()121 inline bool IsBf16TotalOrder() const {
122 return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder();
123 }
124
125 // Returns whether this is a standard comparison, i.e., what you would expect
126 // as the industry standard on most architectures.
IsStandardF32()127 inline bool IsStandardF32() const {
128 return primitive_type_ == PrimitiveType::F32 && IsPartialOrder();
129 }
IsStandardBf16()130 inline bool IsStandardBf16() const {
131 return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder();
132 }
IsStandardS32()133 inline bool IsStandardS32() const {
134 return primitive_type_ == PrimitiveType::S32 && IsTotalOrder();
135 }
IsStandardU32()136 inline bool IsStandardU32() const {
137 return primitive_type_ == PrimitiveType::U32 && IsTotalOrder();
138 }
139
IsIntegralPrimitiveType()140 inline bool IsIntegralPrimitiveType() const {
141 return primitive_util::IsIntegralType(primitive_type_);
142 }
IsFloatingPointPrimitiveType()143 inline bool IsFloatingPointPrimitiveType() const {
144 return primitive_util::IsFloatingPointType(primitive_type_);
145 }
146
147 // Returns whether (a dir a) is always true for this comparison.
148 bool IsReflexive() const;
149
150 // Returns whether (a dir a) is always false for this comparison.
151 bool IsAntireflexive() const;
152
153 // Gets the converse of the given comparison direction (e.g. >= turns to <=).
154 // Useful when commuting operands to get constants into immediate-accepting
155 // positions in the ISA.
156 Comparison Converse() const;
157
158 // Gets the inverse of the given comparison if it exists (e.g. >= turns to <).
159 // Returns optional value because not all inversions may be supported.
160 std::optional<Comparison> Inverse() const;
161
162 // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER"
163 std::string ToString(std::string prefix1 = ".", std::string prefix2 = ".",
164 std::string prefix3 = ".") const;
165
166 // Returns a comparison operator: (T, T) -> bool for this Comparison's
167 // Direction.
168 template <typename T>
GetComparator()169 std::function<bool(T, T)> GetComparator() const {
170 switch (GetDirection()) {
171 case Direction::kEq:
172 return std::equal_to<T>();
173 case Direction::kNe:
174 return std::not_equal_to<T>();
175 case Direction::kGe:
176 return std::greater_equal<T>();
177 case Direction::kGt:
178 return std::greater<T>();
179 case Direction::kLe:
180 return std::less_equal<T>();
181 case Direction::kLt:
182 return std::less<T>();
183 }
184 }
185
186 // Applies the comparison from this Comparison's direction and ordering for
187 // integral types.
188 template <typename T, absl::enable_if_t<std::is_integral<T>::value, int> = 0>
Compare(const T a,const T b)189 bool Compare(const T a, const T b) const {
190 CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
191 return GetComparator<T>()(a, b);
192 }
193
194 // Applies the comparison from this Comparison's direction and ordering
195 // for floating point types.
196 template <typename T,
197 absl::enable_if_t<std::is_floating_point<T>::value ||
198 std::is_same<T, xla::bfloat16>::value,
199 int> = 0>
Compare(const T a,const T b)200 bool Compare(const T a, const T b) const {
201 CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
202 if (IsTotalOrder()) {
203 // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
204 // Reference:
205 // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
206 using R = typename SignedIntegerTypeForSize<sizeof(T)>::type;
207 return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
208 }
209 return GetComparator<T>()(a, b);
210 }
211
212 // Returns the Comparison::Type for the given primitive type. This assumes
213 // that each numerical representation follows the standard behavior, e.g.,
214 // integers are total order and floats are partial order.
215 [[deprecated("Use PrimitiveType and Order")]] static Comparison::Type
216 DefaultComparisonType(PrimitiveType type);
217
218 private:
219 // The direction of the Comparison, e.g., GT.
220 const Direction dir_;
221 // The primitive type of the Comparison operands, e.g., F32.
222 const PrimitiveType primitive_type_;
223 // The ordering of the Comparison, e.g., kPartial.
224 const Order order_;
225 // The Type of the Comparison. This tries to mesh together the ordering and
226 // the numerical data classification.
227 [[deprecated]] const Type type_;
228 };
229
230 using ComparisonDirection = Comparison::Direction;
231 using ComparisonOrder = Comparison::Order;
232
233 inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
234 return os << cmp.ToString();
235 }
236
237 std::string ComparisonDirectionToString(Comparison::Direction direction);
238 std::string ComparisonTypeToString(Comparison::Type type);
239 std::string ComparisonPrimitiveTypeToString(PrimitiveType type);
240 std::string ComparisonOrderToString(Comparison::Order order);
241
242 StatusOr<Comparison::Direction> StringToComparisonDirection(
243 absl::string_view direction);
244 StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison);
245 StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order);
246
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// `key_fn` is forwarded into the returned closure: an rvalue key function is
// moved (so move-only key functions are accepted), an lvalue one is copied. The
// closure therefore owns its key function and never refers back to the
// caller's object.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
253
254 } // namespace xla
255 #endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
256