1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/primitive_util.h"
17
18 #include <limits>
19
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/numbers.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace xla {
28 namespace primitive_util {
29
SignificandWidth(PrimitiveType type)30 int SignificandWidth(PrimitiveType type) {
31 switch (type) {
32 case F32:
33 return std::numeric_limits<float>::digits;
34 case F64:
35 return std::numeric_limits<double>::digits;
36 case BF16:
37 return std::numeric_limits<bfloat16>::digits;
38 case F16:
39 return std::numeric_limits<half>::digits;
40 default:
41 LOG(FATAL) << "Not a floating data type " << type;
42 }
43 }
44
ExponentWidth(PrimitiveType type)45 int ExponentWidth(PrimitiveType type) {
46 // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
47 // biased exponent and a trailing significand field.
48 int total_bit_width = BitWidth(type);
49 // This field contains all bits in the significand other than the leading
50 // digit which is implied by the exponent.
51 int trailing_significand_field_width = SignificandWidth(type) - 1;
52 // The sign is encoded with a single bit.
53 int kSignBitWidth = 1;
54 // The remaining bits are used for encoding the biased exponent.
55 return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
56 }
57
OverflowExponent(PrimitiveType type)58 int OverflowExponent(PrimitiveType type) {
59 // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
60 // integer such that radix raised to the power one less than that integer is a
61 // representable finite floating-point number." as such it does not actually
62 // yield the maximum exponent but the exponent of the first integer which
63 // overflows.
64 switch (type) {
65 case F32:
66 return std::numeric_limits<float>::max_exponent;
67 case F64:
68 return std::numeric_limits<double>::max_exponent;
69 case BF16:
70 return std::numeric_limits<bfloat16>::max_exponent;
71 case F16:
72 return std::numeric_limits<half>::max_exponent;
73 default:
74 LOG(FATAL) << "Not a floating data type " << type;
75 }
76 }
77
IsFloatingPointType(PrimitiveType type)78 bool IsFloatingPointType(PrimitiveType type) {
79 return type == F16 || type == F32 || type == F64 || type == BF16;
80 }
81
IsComplexType(PrimitiveType type)82 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
83
IsSignedIntegralType(PrimitiveType type)84 bool IsSignedIntegralType(PrimitiveType type) {
85 return type == S8 || type == S16 || type == S32 || type == S64;
86 }
87
IsUnsignedIntegralType(PrimitiveType type)88 bool IsUnsignedIntegralType(PrimitiveType type) {
89 return type == U8 || type == U16 || type == U32 || type == U64;
90 }
91
IsIntegralType(PrimitiveType type)92 bool IsIntegralType(PrimitiveType type) {
93 return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
94 }
95
BitWidth(PrimitiveType type)96 int BitWidth(PrimitiveType type) {
97 switch (type) {
98 case PRED:
99 return 1;
100
101 case S8:
102 case U8:
103 return 8;
104
105 case S16:
106 case U16:
107 case F16:
108 case BF16:
109 return 16;
110
111 case U32:
112 case S32:
113 case F32:
114 return 32;
115
116 case U64:
117 case S64:
118 case F64:
119 case C64:
120 return 64;
121
122 case C128:
123 return 128;
124
125 case TUPLE:
126 LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
127
128 case OPAQUE_TYPE:
129 LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
130
131 default:
132 LOG(FATAL) << "Unhandled primitive type " << type;
133 }
134 }
135
UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)136 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
137 switch (src_bitwidth) {
138 case 8:
139 return xla::U8;
140 case 16:
141 return xla::U16;
142 case 32:
143 return xla::U32;
144 case 64:
145 return xla::U64;
146 default:
147 return xla::PRIMITIVE_TYPE_INVALID;
148 }
149 }
150
SignedIntegralTypeForBitWidth(int64 src_bitwidth)151 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) {
152 switch (src_bitwidth) {
153 case 8:
154 return xla::S8;
155 case 16:
156 return xla::S16;
157 case 32:
158 return xla::S32;
159 case 64:
160 return xla::S64;
161 default:
162 return xla::PRIMITIVE_TYPE_INVALID;
163 }
164 }
165
ComplexComponentType(PrimitiveType complex_type)166 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
167 switch (complex_type) {
168 case C64:
169 return F32;
170 case C128:
171 return F64;
172 default:
173 LOG(FATAL) << "Primitive type is not complex: "
174 << PrimitiveType_Name(complex_type);
175 }
176 }
177
IsArrayType(PrimitiveType primitive_type)178 bool IsArrayType(PrimitiveType primitive_type) {
179 return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
180 primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
181 }
182
183 // Class to memoize the computation of
184 // absl::AsciiStrToLower(PrimitiveType_Name(p))
185 // for all PrimitiveType values "p"
186 //
187 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
188 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
189 class PrimitiveTypeNameGenerator {
190 public:
PrimitiveTypeNameGenerator()191 PrimitiveTypeNameGenerator() {
192 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
193 if (i == static_cast<int>(OPAQUE_TYPE)) {
194 lowercase_name_[i] = "opaque";
195 } else if (PrimitiveType_IsValid(i)) {
196 lowercase_name_[i] = absl::AsciiStrToLower(
197 PrimitiveType_Name(static_cast<PrimitiveType>(i)));
198 }
199 }
200 }
LowercaseName(PrimitiveType t)201 const string& LowercaseName(PrimitiveType t) {
202 return lowercase_name_[static_cast<int>(t)];
203 }
204
205 private:
206 string lowercase_name_[PrimitiveType_ARRAYSIZE];
207 };
208
LowercasePrimitiveTypeName(PrimitiveType s)209 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
210 static auto* gen = new PrimitiveTypeNameGenerator();
211 return gen->LowercaseName(s);
212 }
213
214 namespace {
215
216 // Returns a map from lower-case primitive type name to primitive type.
217 //
218 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
219 // the xla::OPAQUE_TYPE enumerator.
GetPrimitiveTypeStringMap()220 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
221 static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
222 static auto* map = new std::unordered_map<string, PrimitiveType>;
223 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
224 if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
225 auto value = static_cast<PrimitiveType>(i);
226 (*map)[LowercasePrimitiveTypeName(value)] = value;
227 }
228 }
229 (*map)["opaque"] = OPAQUE_TYPE;
230 return map;
231 }();
232 return *name_to_type;
233 }
234
235 } // namespace
236
StringToPrimitiveType(absl::string_view name)237 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
238 const auto& map = GetPrimitiveTypeStringMap();
239 auto found = map.find(string(name));
240 if (found == map.end()) {
241 return InvalidArgument("Invalid element type string: \"%s\".", name);
242 }
243 return found->second;
244 }
245
IsPrimitiveTypeName(absl::string_view name)246 bool IsPrimitiveTypeName(absl::string_view name) {
247 const auto& map = GetPrimitiveTypeStringMap();
248 auto found = map.find(string(name));
249 return found != map.end();
250 }
251
252 } // namespace primitive_util
253 } // namespace xla
254