1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/primitive_util.h"
17
18 #include <limits>
19
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/numbers.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace xla {
28 namespace primitive_util {
29
SignificandWidth(PrimitiveType type)30 int SignificandWidth(PrimitiveType type) {
31 switch (type) {
32 case F32:
33 return std::numeric_limits<float>::digits;
34 case F64:
35 return std::numeric_limits<double>::digits;
36 case BF16:
37 return std::numeric_limits<bfloat16>::digits;
38 case F16:
39 return std::numeric_limits<half>::digits;
40 default:
41 LOG(FATAL) << "Not a floating data type " << type;
42 }
43 }
44
ExponentWidth(PrimitiveType type)45 int ExponentWidth(PrimitiveType type) {
46 // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
47 // biased exponent and a trailing significand field.
48 int total_bit_width = BitWidth(type);
49 // This field contains all bits in the significand other than the leading
50 // digit which is implied by the exponent.
51 int trailing_significand_field_width = SignificandWidth(type) - 1;
52 // The sign is encoded with a single bit.
53 int kSignBitWidth = 1;
54 // The remaining bits are used for encoding the biased exponent.
55 return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
56 }
57
OverflowExponent(PrimitiveType type)58 int OverflowExponent(PrimitiveType type) {
59 // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
60 // integer such that radix raised to the power one less than that integer is a
61 // representable finite floating-point number." as such it does not actually
62 // yield the maximum exponent but the exponent of the first integer which
63 // overflows.
64 switch (type) {
65 case F32:
66 return std::numeric_limits<float>::max_exponent;
67 case F64:
68 return std::numeric_limits<double>::max_exponent;
69 case BF16:
70 return std::numeric_limits<bfloat16>::max_exponent;
71 case F16:
72 return std::numeric_limits<half>::max_exponent;
73 default:
74 LOG(FATAL) << "Not a floating data type " << type;
75 }
76 }
77
IsFloatingPointType(PrimitiveType type)78 bool IsFloatingPointType(PrimitiveType type) {
79 return type == F16 || type == F32 || type == F64 || type == BF16;
80 }
81
IsComplexType(PrimitiveType type)82 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
83
IsSignedIntegralType(PrimitiveType type)84 bool IsSignedIntegralType(PrimitiveType type) {
85 return type == S8 || type == S16 || type == S32 || type == S64;
86 }
87
IsUnsignedIntegralType(PrimitiveType type)88 bool IsUnsignedIntegralType(PrimitiveType type) {
89 return type == U8 || type == U16 || type == U32 || type == U64;
90 }
91
IsIntegralType(PrimitiveType type)92 bool IsIntegralType(PrimitiveType type) {
93 return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
94 }
95
BitWidth(PrimitiveType type)96 int BitWidth(PrimitiveType type) {
97 switch (type) {
98 case PRED:
99 return 1;
100
101 case S8:
102 case U8:
103 return 8;
104
105 case S16:
106 case U16:
107 case F16:
108 case BF16:
109 return 16;
110
111 case U32:
112 case S32:
113 case F32:
114 return 32;
115
116 case U64:
117 case S64:
118 case F64:
119 case C64:
120 return 64;
121
122 case C128:
123 return 128;
124
125 case TUPLE:
126 LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
127
128 case OPAQUE_TYPE:
129 LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
130
131 default:
132 LOG(FATAL) << "Unhandled primitive type " << type;
133 }
134 }
135
ByteWidth(PrimitiveType type)136 int ByteWidth(PrimitiveType type) { return CeilOfRatio(BitWidth(type), 8); }
137
UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth)138 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
139 switch (src_bitwidth) {
140 case 8:
141 return xla::U8;
142 case 16:
143 return xla::U16;
144 case 32:
145 return xla::U32;
146 case 64:
147 return xla::U64;
148 default:
149 return xla::PRIMITIVE_TYPE_INVALID;
150 }
151 }
152
SignedIntegralTypeForBitWidth(int64_t src_bitwidth)153 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
154 switch (src_bitwidth) {
155 case 8:
156 return xla::S8;
157 case 16:
158 return xla::S16;
159 case 32:
160 return xla::S32;
161 case 64:
162 return xla::S64;
163 default:
164 return xla::PRIMITIVE_TYPE_INVALID;
165 }
166 }
167
ComplexComponentType(PrimitiveType complex_type)168 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
169 switch (complex_type) {
170 case C64:
171 return F32;
172 case C128:
173 return F64;
174 default:
175 LOG(FATAL) << "Primitive type is not complex: "
176 << PrimitiveType_Name(complex_type);
177 }
178 }
179
IsArrayType(PrimitiveType primitive_type)180 bool IsArrayType(PrimitiveType primitive_type) {
181 return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
182 primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
183 }
184
185 // Class to memoize the computation of
186 // absl::AsciiStrToLower(PrimitiveType_Name(p))
187 // for all PrimitiveType values "p"
188 //
189 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
190 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
191 class PrimitiveTypeNameGenerator {
192 public:
PrimitiveTypeNameGenerator()193 PrimitiveTypeNameGenerator() {
194 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
195 if (i == static_cast<int>(OPAQUE_TYPE)) {
196 lowercase_name_[i] = "opaque";
197 } else if (PrimitiveType_IsValid(i)) {
198 lowercase_name_[i] = absl::AsciiStrToLower(
199 PrimitiveType_Name(static_cast<PrimitiveType>(i)));
200 }
201 }
202 }
LowercaseName(PrimitiveType t)203 const string& LowercaseName(PrimitiveType t) {
204 return lowercase_name_[static_cast<int>(t)];
205 }
206
207 private:
208 string lowercase_name_[PrimitiveType_ARRAYSIZE];
209 };
210
LowercasePrimitiveTypeName(PrimitiveType s)211 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
212 static auto* gen = new PrimitiveTypeNameGenerator();
213 return gen->LowercaseName(s);
214 }
215
216 namespace {
217
218 // Returns a map from lower-case primitive type name to primitive type.
219 //
220 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
221 // the xla::OPAQUE_TYPE enumerator.
GetPrimitiveTypeStringMap()222 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
223 static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
224 static auto* map = new std::unordered_map<string, PrimitiveType>;
225 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
226 if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
227 auto value = static_cast<PrimitiveType>(i);
228 (*map)[LowercasePrimitiveTypeName(value)] = value;
229 }
230 }
231 (*map)["opaque"] = OPAQUE_TYPE;
232 return map;
233 }();
234 return *name_to_type;
235 }
236
237 } // namespace
238
StringToPrimitiveType(absl::string_view name)239 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
240 const auto& map = GetPrimitiveTypeStringMap();
241 auto found = map.find(string(name));
242 if (found == map.end()) {
243 return InvalidArgument("Invalid element type string: \"%s\".", name);
244 }
245 return found->second;
246 }
247
IsPrimitiveTypeName(absl::string_view name)248 bool IsPrimitiveTypeName(absl::string_view name) {
249 const auto& map = GetPrimitiveTypeStringMap();
250 auto found = map.find(string(name));
251 return found != map.end();
252 }
253
254 } // namespace primitive_util
255 } // namespace xla
256