1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/comparison_util.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "tensorflow/compiler/xla/util.h"
19
20 namespace xla {
21
ComparisonDirectionToString(Comparison::Direction direction)22 std::string ComparisonDirectionToString(Comparison::Direction direction) {
23 switch (direction) {
24 case Comparison::Direction::kEq:
25 return "EQ";
26 case Comparison::Direction::kNe:
27 return "NE";
28 case Comparison::Direction::kGe:
29 return "GE";
30 case Comparison::Direction::kGt:
31 return "GT";
32 case Comparison::Direction::kLe:
33 return "LE";
34 case Comparison::Direction::kLt:
35 return "LT";
36 default:
37 LOG(FATAL) << "Attempted to print uninitialized comparison direction";
38 }
39 }
40
StringToComparisonDirection(absl::string_view direction_name)41 StatusOr<Comparison::Direction> StringToComparisonDirection(
42 absl::string_view direction_name) {
43 static auto* direction_map =
44 new absl::flat_hash_map<string, Comparison::Direction>({
45 {"EQ", Comparison::Direction::kEq},
46 {"NE", Comparison::Direction::kNe},
47 {"GE", Comparison::Direction::kGe},
48 {"GT", Comparison::Direction::kGt},
49 {"LE", Comparison::Direction::kLe},
50 {"LT", Comparison::Direction::kLt},
51 });
52 auto it = direction_map->find(direction_name);
53 if (it == direction_map->end()) {
54 return InvalidArgument("Unknown comparison direction: %s", direction_name);
55 }
56 return it->second;
57 }
58
StringToComparisonType(absl::string_view compare_type_name)59 StatusOr<Comparison::Type> StringToComparisonType(
60 absl::string_view compare_type_name) {
61 static auto* type_map = new absl::flat_hash_map<string, Comparison::Type>({
62 {"FLOAT", Comparison::Type::kFloat},
63 {"TOTALORDER", Comparison::Type::kFloatTotalOrder},
64 {"SIGNED", Comparison::Type::kSigned},
65 {"UNSIGNED", Comparison::Type::kUnsigned},
66 });
67 auto it = type_map->find(compare_type_name);
68 if (it == type_map->end()) {
69 return InvalidArgument("Unknown comparison type: %s", compare_type_name);
70 }
71 return it->second;
72 }
73
ComparisonTypeToString(Comparison::Type type)74 std::string ComparisonTypeToString(Comparison::Type type) {
75 switch (type) {
76 case Comparison::Type::kFloat:
77 return "FLOAT";
78 case Comparison::Type::kFloatTotalOrder:
79 return "TOTALORDER";
80 case Comparison::Type::kSigned:
81 return "SIGNED";
82 case Comparison::Type::kUnsigned:
83 return "UNSIGNED";
84 default:
85 LOG(FATAL) << "Attempted to print incomplete comparison type";
86 }
87 }
88
Comparison(Direction dir,PrimitiveType type)89 Comparison::Comparison(Direction dir, PrimitiveType type)
90 : dir_(dir), type_(DefaultComparisonType(type)) {}
91
DefaultComparisonType(PrimitiveType type)92 Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
93 switch (type) {
94 case S8:
95 case S16:
96 case S32:
97 case S64:
98 return Type::kSigned;
99 case PRED:
100 case U8:
101 case U16:
102 case U32:
103 case U64:
104 return Type::kUnsigned;
105 case F16:
106 case F32:
107 case BF16:
108 case F64:
109 case C64:
110 case C128:
111 return Type::kFloat;
112 default:
113 LOG(FATAL) << "Unsupported comparison mode."
114 << PrimitiveType_Name(type) << "\n";
115 }
116 }
117
Converse() const118 Comparison Comparison::Converse() const {
119 return Comparison(Converse(dir_), type_);
120 }
121
Inverse() const122 absl::optional<Comparison> Comparison::Inverse() const {
123 switch (type_) {
124 case Type::kFloat:
125 // Floating-point comparisons don't have inverses unless total order is
126 // supported (e.g. comparison can return true if one operand is NaN).
127 return absl::nullopt;
128 case Type::kFloatTotalOrder:
129 case Type::kSigned:
130 case Type::kUnsigned:
131 return Comparison(Inverse(dir_), type_);
132 }
133 }
134
IsReflexive() const135 bool Comparison::IsReflexive() const {
136 switch (dir_) {
137 case Direction::kEq:
138 case Direction::kGe:
139 case Direction::kLe:
140 return IsSigned() || IsUnsigned() || IsFloatTotalOrder();
141 case Direction::kNe:
142 case Direction::kGt:
143 case Direction::kLt:
144 return false;
145 }
146 }
147
IsAntireflexive() const148 bool Comparison::IsAntireflexive() const {
149 switch (dir_) {
150 case Direction::kNe:
151 return IsSigned() || IsUnsigned() || IsFloatTotalOrder();
152 case Direction::kGt:
153 case Direction::kLt:
154 return true;
155 case Direction::kEq:
156 case Direction::kGe:
157 case Direction::kLe:
158 return false;
159 }
160 }
161
Converse(Comparison::Direction dir)162 /* static */ Comparison::Direction Comparison::Converse(
163 Comparison::Direction dir) {
164 switch (dir) {
165 case Comparison::Direction::kEq:
166 return Comparison::Direction::kEq;
167 case Comparison::Direction::kNe:
168 return Comparison::Direction::kNe;
169 case Comparison::Direction::kGe:
170 return Comparison::Direction::kLe;
171 case Comparison::Direction::kGt:
172 return Comparison::Direction::kLt;
173 case Comparison::Direction::kLe:
174 return Comparison::Direction::kGe;
175 case Comparison::Direction::kLt:
176 return Comparison::Direction::kGt;
177 }
178 }
179
Inverse(Comparison::Direction dir)180 /* static */ Comparison::Direction Comparison::Inverse(
181 Comparison::Direction dir) {
182 switch (dir) {
183 case Direction::kEq:
184 return Direction::kNe;
185 case Direction::kNe:
186 return Direction::kEq;
187 case Direction::kGe:
188 return Direction::kLt;
189 case Direction::kGt:
190 return Direction::kLe;
191 case Direction::kLe:
192 return Direction::kGt;
193 case Direction::kLt:
194 return Direction::kGe;
195 }
196 }
197
ToString(std::string prefix1,std::string prefix2) const198 std::string Comparison::ToString(std::string prefix1,
199 std::string prefix2) const {
200 return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +
201 std::string(ComparisonTypeToString(type_));
202 }
203 } // namespace xla
204