• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
18 
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/bfloat16.h"
31 
32 namespace xla {
33 
34 // A utility class for primitive comparisons. A comparison includes three
35 // components: the type of the elements being compared (F32, S16, etc), whether
36 // it is a partial or total order comparison, and the actual comparison operator
37 // (==, <=, >, etc).
38 //
39 // Note that integer comparisons are always total order. Float comparisons can
40 // be either total or partial order.
41 //
42 // Some examples:
43 //
44 //   Comparison a(
45 //     Comparison::Direction::kLt,
46 //     xla::PrimitiveType::BF16,
47 //     Comparison::Order::kTotal
48 //   );
49 //   a.ToString(); /* ".LT.BF16.TOTALORDER" */
50 //
51 //   Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32);
52 //   b.IsTotalOrder(); /* true */
53 class Comparison {
54  public:
55   // Represents the ordering of the comparison.
56   enum class Order : uint8_t {
57     // https://en.wikipedia.org/wiki/Total_order
58     kTotal,
59     // https://en.wikipedia.org/wiki/Partially_ordered_set
60     kPartial,
61   };
62 
63   // Represents different comparison operations.
64   enum class Direction : uint8_t {
65     kEq,
66     kNe,
67     kGe,
68     kGt,
69     kLe,
70     kLt,
71   };
72 
73   // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType
74   // and Comparison::Order, since there are multiple floating point
75   // representations that support total ordering.
76   enum class [[deprecated("Use PrimitiveType and Order")]] Type : uint8_t{
77       kFloat,
78       kFloatTotalOrder,
79       kSigned,
80       kUnsigned,
81   };
82 
83   Comparison() = delete;
84 
85   // This will default to the expected behavior for Comparison::Order: integers
86   // will use total ordering, and floats will use partial ordering.
87   explicit Comparison(Direction dir, PrimitiveType type);
88 
89   // Pass in a Comparison::Order to specify a non-default ordering, e.g., some
90   // targets may support total order floating point type comparisons.
91   explicit Comparison(Direction dir, PrimitiveType type, Order order);
92 
93   // Returns a comparison with a primitive type matching the Comparison::Type
94   // and using a default bit width of 32. For example,
95   // Comparison(Direction::kLt, Type::kFloat).PrimitiveType()  /* F32 */
96   [[deprecated(
97       "Use Comparison(Comparison::Direction, "
98       "PrimitiveType)")]] explicit Comparison(Direction dir, Type type);
99 
GetDirection()100   inline Direction GetDirection() const { return dir_; }
GetPrimitiveType()101   inline PrimitiveType GetPrimitiveType() const { return primitive_type_; }
GetOrder()102   inline Order GetOrder() const { return order_; }
103 
GetType()104   [[deprecated("Use GetPrimitiveType() and GetOrder()")]] inline Type GetType()
105       const {
106     return type_;
107   }
108 
IsEq()109   inline bool IsEq() const { return dir_ == Direction::kEq; }
IsNe()110   inline bool IsNe() const { return dir_ == Direction::kNe; }
IsGe()111   inline bool IsGe() const { return dir_ == Direction::kGe; }
IsGt()112   inline bool IsGt() const { return dir_ == Direction::kGt; }
IsLt()113   inline bool IsLt() const { return dir_ == Direction::kLt; }
IsTotalOrder()114   inline bool IsTotalOrder() const { return order_ == Order::kTotal; }
IsPartialOrder()115   inline bool IsPartialOrder() const { return order_ == Order::kPartial; }
116 
117   // Returns whether this is a floating point total order comparison.
IsF32TotalOrder()118   inline bool IsF32TotalOrder() const {
119     return primitive_type_ == PrimitiveType::F32 && IsTotalOrder();
120   }
IsBf16TotalOrder()121   inline bool IsBf16TotalOrder() const {
122     return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder();
123   }
124 
125   // Returns whether this is a standard comparison, i.e., what you would expect
126   // as the industry standard on most architectures.
IsStandardF32()127   inline bool IsStandardF32() const {
128     return primitive_type_ == PrimitiveType::F32 && IsPartialOrder();
129   }
IsStandardBf16()130   inline bool IsStandardBf16() const {
131     return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder();
132   }
IsStandardS32()133   inline bool IsStandardS32() const {
134     return primitive_type_ == PrimitiveType::S32 && IsTotalOrder();
135   }
IsStandardU32()136   inline bool IsStandardU32() const {
137     return primitive_type_ == PrimitiveType::U32 && IsTotalOrder();
138   }
139 
IsIntegralPrimitiveType()140   inline bool IsIntegralPrimitiveType() const {
141     return primitive_util::IsIntegralType(primitive_type_);
142   }
IsFloatingPointPrimitiveType()143   inline bool IsFloatingPointPrimitiveType() const {
144     return primitive_util::IsFloatingPointType(primitive_type_);
145   }
146 
147   // Returns whether (a dir a) is always true for this comparison.
148   bool IsReflexive() const;
149 
150   // Returns whether (a dir a) is always false for this comparison.
151   bool IsAntireflexive() const;
152 
153   // Gets the converse of the given comparison direction (e.g. >= turns to <=).
154   // Useful when commuting operands to get constants into immediate-accepting
155   // positions in the ISA.
156   Comparison Converse() const;
157 
158   // Gets the inverse of the given comparison if it exists (e.g. >= turns to <).
159   // Returns optional value because not all inversions may be supported.
160   std::optional<Comparison> Inverse() const;
161 
162   // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER"
163   std::string ToString(std::string prefix1 = ".", std::string prefix2 = ".",
164                        std::string prefix3 = ".") const;
165 
166   // Returns a comparison operator: (T, T) -> bool for this Comparison's
167   // Direction.
168   template <typename T>
GetComparator()169   std::function<bool(T, T)> GetComparator() const {
170     switch (GetDirection()) {
171       case Direction::kEq:
172         return std::equal_to<T>();
173       case Direction::kNe:
174         return std::not_equal_to<T>();
175       case Direction::kGe:
176         return std::greater_equal<T>();
177       case Direction::kGt:
178         return std::greater<T>();
179       case Direction::kLe:
180         return std::less_equal<T>();
181       case Direction::kLt:
182         return std::less<T>();
183     }
184   }
185 
186   // Applies the comparison from this Comparison's direction and ordering for
187   // integral types.
188   template <typename T, absl::enable_if_t<std::is_integral<T>::value, int> = 0>
Compare(const T a,const T b)189   bool Compare(const T a, const T b) const {
190     CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
191     return GetComparator<T>()(a, b);
192   }
193 
194   // Applies the comparison from this Comparison's direction and ordering
195   // for floating point types.
196   template <typename T,
197             absl::enable_if_t<std::is_floating_point<T>::value ||
198                                   std::is_same<T, xla::bfloat16>::value,
199                               int> = 0>
Compare(const T a,const T b)200   bool Compare(const T a, const T b) const {
201     CHECK(primitive_util::IsCanonicalRepresentation<T>(primitive_type_));
202     if (IsTotalOrder()) {
203       //  -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
204       // Reference:
205       // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
206       using R = typename SignedIntegerTypeForSize<sizeof(T)>::type;
207       return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
208     }
209     return GetComparator<T>()(a, b);
210   }
211 
212   // Returns the Comparison::Type for the given primitive type. This assumes
213   // that each numerical representation follows the standard behavior, e.g.,
214   // integers are total order and floats are partial order.
215   [[deprecated("Use PrimitiveType and Order")]] static Comparison::Type
216   DefaultComparisonType(PrimitiveType type);
217 
218  private:
219   // The direction of the Comparison, e.g., GT.
220   const Direction dir_;
221   // The primitive type of the Comparison operands, e.g., F32.
222   const PrimitiveType primitive_type_;
223   // The ordering of the Comparison, e.g., kPartial.
224   const Order order_;
225   // The Type of the Comparison. This tries to mesh together the ordering and
226   // the numerical data classification.
227   [[deprecated]] const Type type_;
228 };
229 
230 using ComparisonDirection = Comparison::Direction;
231 using ComparisonOrder = Comparison::Order;
232 
233 inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
234   return os << cmp.ToString();
235 }
236 
237 std::string ComparisonDirectionToString(Comparison::Direction direction);
238 std::string ComparisonTypeToString(Comparison::Type type);
239 std::string ComparisonPrimitiveTypeToString(PrimitiveType type);
240 std::string ComparisonOrderToString(Comparison::Order order);
241 
242 StatusOr<Comparison::Direction> StringToComparisonDirection(
243     absl::string_view direction);
244 StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison);
245 StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order);
246 
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// `key_fn` is stored inside the returned closure. Lvalue arguments are copied
// and rvalue arguments are moved, so move-only key functions are accepted.
// The closure's call operator is const, so `key_fn` must be invocable as a
// const object.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  // Perfect-forward into the closure: temporaries are moved instead of copied.
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
253 
254 }  // namespace xla
255 #endif  // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
256