1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/primitive_util.h"
17
18 #include <limits>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/numbers.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27
28 namespace xla {
29 namespace primitive_util {
30
SignificandWidth(PrimitiveType type)31 int SignificandWidth(PrimitiveType type) {
32 switch (type) {
33 case F32:
34 return std::numeric_limits<float>::digits;
35 case F64:
36 return std::numeric_limits<double>::digits;
37 case BF16:
38 return std::numeric_limits<bfloat16>::digits;
39 case F16:
40 return std::numeric_limits<half>::digits;
41 default:
42 LOG(FATAL) << "Not a floating data type " << type;
43 }
44 }
45
ExponentWidth(PrimitiveType type)46 int ExponentWidth(PrimitiveType type) {
47 // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
48 // biased exponent and a trailing significand field.
49 int total_bit_width = BitWidth(type);
50 // This field contains all bits in the significand other than the leading
51 // digit which is implied by the exponent.
52 int trailing_significand_field_width = SignificandWidth(type) - 1;
53 // The sign is encoded with a single bit.
54 int kSignBitWidth = 1;
55 // The remaining bits are used for encoding the biased exponent.
56 return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
57 }
58
OverflowExponent(PrimitiveType type)59 int OverflowExponent(PrimitiveType type) {
60 // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
61 // integer such that radix raised to the power one less than that integer is a
62 // representable finite floating-point number." as such it does not actually
63 // yield the maximum exponent but the exponent of the first integer which
64 // overflows.
65 switch (type) {
66 case F32:
67 return std::numeric_limits<float>::max_exponent;
68 case F64:
69 return std::numeric_limits<double>::max_exponent;
70 case BF16:
71 return std::numeric_limits<bfloat16>::max_exponent;
72 case F16:
73 return std::numeric_limits<half>::max_exponent;
74 default:
75 LOG(FATAL) << "Not a floating data type " << type;
76 }
77 }
78
IsFloatingPointType(PrimitiveType type)79 bool IsFloatingPointType(PrimitiveType type) {
80 return type == F16 || type == F32 || type == F64 || type == BF16;
81 }
82
IsComplexType(PrimitiveType type)83 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
84
IsSignedIntegralType(PrimitiveType type)85 bool IsSignedIntegralType(PrimitiveType type) {
86 return type == S8 || type == S16 || type == S32 || type == S64;
87 }
88
IsUnsignedIntegralType(PrimitiveType type)89 bool IsUnsignedIntegralType(PrimitiveType type) {
90 return type == U8 || type == U16 || type == U32 || type == U64;
91 }
92
IsIntegralType(PrimitiveType type)93 bool IsIntegralType(PrimitiveType type) {
94 return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
95 }
96
BitWidth(PrimitiveType type)97 int BitWidth(PrimitiveType type) {
98 switch (type) {
99 case PRED:
100 return 1;
101
102 case S8:
103 case U8:
104 return 8;
105
106 case S16:
107 case U16:
108 case F16:
109 case BF16:
110 return 16;
111
112 case U32:
113 case S32:
114 case F32:
115 return 32;
116
117 case U64:
118 case S64:
119 case F64:
120 case C64:
121 return 64;
122
123 case C128:
124 return 128;
125
126 case TUPLE:
127 LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
128
129 case OPAQUE_TYPE:
130 LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
131
132 default:
133 LOG(FATAL) << "Unhandled primitive type " << type;
134 }
135 }
136
ByteWidth(PrimitiveType type)137 int ByteWidth(PrimitiveType type) { return CeilOfRatio(BitWidth(type), 8); }
138
UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth)139 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
140 switch (src_bitwidth) {
141 case 8:
142 return xla::U8;
143 case 16:
144 return xla::U16;
145 case 32:
146 return xla::U32;
147 case 64:
148 return xla::U64;
149 default:
150 return xla::PRIMITIVE_TYPE_INVALID;
151 }
152 }
153
SignedIntegralTypeForBitWidth(int64_t src_bitwidth)154 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
155 switch (src_bitwidth) {
156 case 8:
157 return xla::S8;
158 case 16:
159 return xla::S16;
160 case 32:
161 return xla::S32;
162 case 64:
163 return xla::S64;
164 default:
165 return xla::PRIMITIVE_TYPE_INVALID;
166 }
167 }
168
ComplexComponentType(PrimitiveType complex_type)169 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
170 switch (complex_type) {
171 case C64:
172 return F32;
173 case C128:
174 return F64;
175 default:
176 LOG(FATAL) << "Primitive type is not complex: "
177 << PrimitiveType_Name(complex_type);
178 }
179 }
180
181 // Class to memoize the computation of
182 // absl::AsciiStrToLower(PrimitiveType_Name(p))
183 // for all PrimitiveType values "p"
184 //
185 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
186 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
187 class PrimitiveTypeNameGenerator {
188 public:
PrimitiveTypeNameGenerator()189 PrimitiveTypeNameGenerator() {
190 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
191 if (i == static_cast<int>(OPAQUE_TYPE)) {
192 lowercase_name_[i] = "opaque";
193 } else if (PrimitiveType_IsValid(i)) {
194 lowercase_name_[i] = absl::AsciiStrToLower(
195 PrimitiveType_Name(static_cast<PrimitiveType>(i)));
196 }
197 }
198 }
LowercaseName(PrimitiveType t)199 const std::string& LowercaseName(PrimitiveType t) {
200 return lowercase_name_[static_cast<int>(t)];
201 }
202
203 private:
204 std::string lowercase_name_[PrimitiveType_ARRAYSIZE];
205 };
206
LowercasePrimitiveTypeName(PrimitiveType s)207 const std::string& LowercasePrimitiveTypeName(PrimitiveType s) {
208 static auto* gen = new PrimitiveTypeNameGenerator();
209 return gen->LowercaseName(s);
210 }
211
212 namespace {
213
214 // Returns a map from lower-case primitive type name to primitive type.
215 //
216 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
217 // the xla::OPAQUE_TYPE enumerator.
218 const absl::flat_hash_map<std::string, PrimitiveType>&
GetPrimitiveTypeStringMap()219 GetPrimitiveTypeStringMap() {
220 static absl::flat_hash_map<std::string, PrimitiveType>* name_to_type = [] {
221 static auto* map = new absl::flat_hash_map<std::string, PrimitiveType>;
222 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
223 if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
224 auto value = static_cast<PrimitiveType>(i);
225 (*map)[LowercasePrimitiveTypeName(value)] = value;
226 }
227 }
228 (*map)["opaque"] = OPAQUE_TYPE;
229 return map;
230 }();
231 return *name_to_type;
232 }
233
234 } // namespace
235
StringToPrimitiveType(absl::string_view name)236 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
237 const auto& map = GetPrimitiveTypeStringMap();
238 auto found = map.find(std::string(name));
239 if (found == map.end()) {
240 return InvalidArgument("Invalid element type string: \"%s\".", name);
241 }
242 return found->second;
243 }
244
IsPrimitiveTypeName(absl::string_view name)245 bool IsPrimitiveTypeName(absl::string_view name) {
246 const auto& map = GetPrimitiveTypeStringMap();
247 auto found = map.find(std::string(name));
248 return found != map.end();
249 }
250
251 } // namespace primitive_util
252 } // namespace xla
253