1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with XLA primitive types.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20
#include <cstdint>
#include <string>
#include <tuple>
#include <type_traits>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
29
30 namespace xla {
31 namespace primitive_util {
32
// Floating-point format queries. Only the declarations appear in this header;
// the definitions live elsewhere.

// Returns the count of significand (mantissa) bits for float datatypes.
// For non-float datatypes, results in a LOG(FATAL).
// NOTE(review): whether the count includes the implicit leading bit is not
// visible from this header -- confirm against the definition.
int SignificandWidth(PrimitiveType type);

// Returns the count of exponent bits for float datatypes.
// For non-float datatypes, results in a LOG(FATAL).
int ExponentWidth(PrimitiveType type);

// Returns the exponent of the smallest number which cannot be represented.
// For non-float datatypes, results in a LOG(FATAL).
// NOTE(review): presumably this is the smallest power-of-two exponent that
// overflows the type -- confirm against the definition.
int OverflowExponent(PrimitiveType type);
44
45 // Returns the XLA primitive type (eg, F32) corresponding to the given
46 // template parameter native type (eg, float).
47 template <typename NativeT>
NativeToPrimitiveType()48 PrimitiveType NativeToPrimitiveType() {
49 // Make the expression depend on the template parameter NativeT so
50 // that this compile-time error only appears if this function is
51 // instantiated with some concrete type that is not specialized
52 // below.
53 static_assert(!std::is_same<NativeT, NativeT>::value,
54 "Cannot map native type to primitive type.");
55 return PRIMITIVE_TYPE_INVALID;
56 }
57
58 // Declarations of specializations for each native type which correspond to a
59 // XLA primitive type. As an optimization, these are declared inline in the
60 // header.
61 template <>
62 inline PrimitiveType NativeToPrimitiveType<bool>() {
63 return PRED;
64 }
65
66 // Unsigned integer
67 template <>
68 inline PrimitiveType NativeToPrimitiveType<uint8_t>() {
69 return U8;
70 }
71
72 template <>
73 inline PrimitiveType NativeToPrimitiveType<uint16_t>() {
74 return U16;
75 }
76
77 template <>
78 inline PrimitiveType NativeToPrimitiveType<uint32_t>() {
79 return U32;
80 }
81
82 template <>
83 inline PrimitiveType NativeToPrimitiveType<uint64_t>() {
84 return U64;
85 }
86
87 // Signed integer
88 template <>
89 inline PrimitiveType NativeToPrimitiveType<int8_t>() {
90 return S8;
91 }
92
93 template <>
94 inline PrimitiveType NativeToPrimitiveType<int16_t>() {
95 return S16;
96 }
97
98 template <>
99 inline PrimitiveType NativeToPrimitiveType<int32_t>() {
100 return S32;
101 }
102
103 template <>
104 inline PrimitiveType NativeToPrimitiveType<int64_t>() {
105 return S64;
106 }
107
108 // Floating point
109 template <>
110 inline PrimitiveType NativeToPrimitiveType<float>() {
111 return F32;
112 }
113
114 template <>
115 inline PrimitiveType NativeToPrimitiveType<double>() {
116 return F64;
117 }
118
119 template <>
120 inline PrimitiveType NativeToPrimitiveType<half>() {
121 return F16;
122 }
123
124 template <>
125 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
126 return BF16;
127 }
128
129 // Complex
130 template <>
131 inline PrimitiveType NativeToPrimitiveType<complex64>() {
132 return C64;
133 }
134
135 template <>
136 inline PrimitiveType NativeToPrimitiveType<complex128>() {
137 return C128;
138 }
139
// Type-category predicates. Only the declarations appear in this header, so
// the notes below are inferred from the names and from how the predicates are
// used in this file; confirm them against the definitions.

// NOTE(review): presumably true exactly for the floating-point element types
// (F16, BF16, F32, F64) -- confirm.
bool IsFloatingPointType(PrimitiveType type);

// NOTE(review): presumably true exactly for the complex element types
// (C64, C128) -- confirm.
bool IsComplexType(PrimitiveType type);

// NOTE(review): presumably true exactly for S8, S16, S32 and S64 -- confirm.
bool IsSignedIntegralType(PrimitiveType type);

// NOTE(review): presumably true exactly for U8, U16, U32 and U64 -- confirm.
bool IsUnsignedIntegralType(PrimitiveType type);

// NOTE(review): presumably true for any signed or unsigned integer type;
// whether PRED counts as integral is not visible here -- confirm.
bool IsIntegralType(PrimitiveType type);
149
150 // Returns true if values of the given primitive type are held in array shapes.
IsArrayType(PrimitiveType primitive_type)151 inline constexpr bool IsArrayType(PrimitiveType primitive_type) {
152 return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
153 primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
154 }
155
// Size and related-type queries. Only the declarations appear in this header.

// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);

// Returns the number of bytes in the representation for a given type.
int ByteWidth(PrimitiveType type);

// NOTE(review): presumably returns the unsigned integer primitive type that is
// exactly `src_bitwidth` bits wide; the result for unsupported widths is not
// visible from this header -- confirm against the definition.
PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth);

// NOTE(review): presumably returns the signed integer primitive type that is
// exactly `src_bitwidth` bits wide; the result for unsupported widths is not
// visible from this header -- confirm against the definition.
PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth);

// Returns the type of the real and imaginary components underlying the given
// complex type.
// LOG(FATAL)'s if complex_type is not complex.
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
169
170 // Returns the higher-precision element type if a and b are both floating
171 // point types; otherwise, checks that they have the same element type
172 // and returns it.
HigherPrecisionType(PrimitiveType a,PrimitiveType b)173 inline PrimitiveType HigherPrecisionType(PrimitiveType a, PrimitiveType b) {
174 // Returns a tuple where the elements are lexicographically ordered in terms
175 // of importance.
176 auto type_properties = [](PrimitiveType type) {
177 auto component_type =
178 IsComplexType(type) ? ComplexComponentType(type) : type;
179 return std::make_tuple(
180 // Prefer complex types over non-complex types.
181 IsComplexType(type),
182 // Prefer floating point types with more range over other
183 // floating-point types or non-floating point types.
184 IsFloatingPointType(component_type) ? OverflowExponent(component_type)
185 : -1,
186 // Prefer floating point types with more precision over less precise
187 // types.
188 IsFloatingPointType(component_type) ? SignificandWidth(component_type)
189 : -1,
190 // Prefer wider types over narrower types.
191 BitWidth(component_type),
192 // Prefer signed integer types over unsigned integer types.
193 IsSignedIntegralType(component_type));
194 };
195 auto a_properties = type_properties(a);
196 auto b_properties = type_properties(b);
197 if (a_properties > b_properties) {
198 return a;
199 }
200 if (b_properties > a_properties) {
201 return b;
202 }
203 CHECK_EQ(a, b);
204 return a;
205 }
206
207 // Returns true if a convert from from_type to to_type loses no precision.
CastPreservesValues(PrimitiveType from_type,PrimitiveType to_type)208 inline bool CastPreservesValues(PrimitiveType from_type,
209 PrimitiveType to_type) {
210 // * -> *
211 if (from_type == to_type) {
212 return true;
213 }
214 // PRED -> *
215 if (from_type == PRED) {
216 return true;
217 }
218 // ~PRED -> PRED is not safe because it drops almost all numbers.
219 if (to_type == PRED) {
220 return false;
221 }
222 // * -> C is safe if the components of * and C can be safely converted.
223 if (primitive_util::IsComplexType(to_type)) {
224 auto from_component_type =
225 primitive_util::IsComplexType(from_type)
226 ? primitive_util::ComplexComponentType(from_type)
227 : from_type;
228 auto to_component_type = primitive_util::ComplexComponentType(to_type);
229 return CastPreservesValues(from_component_type, to_component_type);
230 }
231 // ~C -> C is not safe because it drops imaginary components.
232 if (primitive_util::IsComplexType(from_type)) {
233 return false;
234 }
235 // F -> F is safe if the exponent and significand are preserved.
236 if (primitive_util::IsFloatingPointType(from_type) &&
237 primitive_util::IsFloatingPointType(to_type)) {
238 return primitive_util::SignificandWidth(from_type) <=
239 primitive_util::SignificandWidth(to_type) &&
240 primitive_util::ExponentWidth(from_type) <=
241 primitive_util::ExponentWidth(to_type) &&
242 primitive_util::OverflowExponent(from_type) <=
243 primitive_util::OverflowExponent(to_type);
244 }
245 // F -> I is not safe because it drops fractional numbers.
246 if (!primitive_util::IsIntegralType(from_type)) {
247 return false;
248 }
249 // An n-bit unsigned integer takes on values from [0, 2^n - 1].
250 // An n-bit signed integer takes on values from [-2^(n-1), 2^(n-1) - 1].
251 // from_bits/to_bits considers the number of non-sign bits.
252 const int from_bits = primitive_util::IsSignedIntegralType(from_type)
253 ? primitive_util::BitWidth(from_type) - 1
254 : primitive_util::BitWidth(from_type);
255 const int to_bits = primitive_util::IsSignedIntegralType(to_type)
256 ? primitive_util::BitWidth(to_type) - 1
257 : primitive_util::BitWidth(to_type);
258 // I -> F is safe if the integer can be represented exactly.
259 if (primitive_util::IsFloatingPointType(to_type)) {
260 // In both cases, we need to handle an exponent of n-1.
261 // However, the significand needed to represent signed two's complement
262 // numbers is smaller by one bit because it will only have a non-zero
263 // trailing significand field when the exponent is smaller than n-1.
264 return from_bits <= primitive_util::SignificandWidth(to_type) &&
265 primitive_util::BitWidth(from_type) - 1 <
266 primitive_util::OverflowExponent(to_type);
267 }
268 // S -> U is not safe because it drops negative numbers.
269 if (primitive_util::IsSignedIntegralType(from_type) &&
270 primitive_util::IsUnsignedIntegralType(to_type)) {
271 return false;
272 }
273 // I -> I is safe if the integer can be represented exactly; we've already
274 // ensured that signed to unsigned conversions won't happen here.
275 CHECK(primitive_util::IsIntegralType(to_type));
276 return from_bits <= to_bits;
277 }
278
279 // Returns the native type (eg, float) corresponding to the given template
280 // parameter XLA primitive type (eg, F32).
281 template <PrimitiveType>
282 struct PrimitiveTypeToNative;
283
284 // Declarations of specializations for each native type which correspond to a
285 // XLA primitive type.
286 template <>
287 struct PrimitiveTypeToNative<PRED> {
288 using type = bool;
289 };
290
291 // Unsigned integer
292 template <>
293 struct PrimitiveTypeToNative<U8> {
294 using type = uint8_t;
295 };
296
297 template <>
298 struct PrimitiveTypeToNative<U16> {
299 using type = uint16_t;
300 };
301
302 template <>
303 struct PrimitiveTypeToNative<U32> {
304 using type = uint32_t;
305 };
306
307 template <>
308 struct PrimitiveTypeToNative<U64> {
309 using type = uint64_t;
310 };
311
312 // Signed integer
313 template <>
314 struct PrimitiveTypeToNative<S8> {
315 using type = int8_t;
316 };
317
318 template <>
319 struct PrimitiveTypeToNative<S16> {
320 using type = int16_t;
321 };
322
323 template <>
324 struct PrimitiveTypeToNative<S32> {
325 using type = int32_t;
326 };
327
328 template <>
329 struct PrimitiveTypeToNative<S64> {
330 using type = int64_t;
331 };
332
333 // Floating point
334 template <>
335 struct PrimitiveTypeToNative<F32> {
336 using type = float;
337 };
338 template <>
339 struct PrimitiveTypeToNative<F64> {
340 using type = double;
341 };
342 template <>
343 struct PrimitiveTypeToNative<F16> {
344 using type = half;
345 };
346
347 template <>
348 struct PrimitiveTypeToNative<BF16> {
349 using type = bfloat16;
350 };
351
352 // Complex
353 template <>
354 struct PrimitiveTypeToNative<C64> {
355 using type = complex64;
356 };
357
358 template <>
359 struct PrimitiveTypeToNative<C128> {
360 using type = complex128;
361 };
362
// Conversions between primitive types and their lower-case names. Only the
// declarations appear in this header.

// Returns the lower-case name of the given primitive type.
// The name is returned by const reference.
// NOTE(review): the lifetime of the referenced string is not visible from this
// header; presumably it has static storage duration -- confirm.
const std::string& LowercasePrimitiveTypeName(PrimitiveType s);

// Returns the PrimitiveType matching the given name. The given name is expected
// to be lower-case.
// NOTE(review): presumably a non-OK status is returned when no primitive type
// has that name -- confirm against the definition.
StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);

// Returns true if the given name is a primitive type string (lower-case).
bool IsPrimitiveTypeName(absl::string_view name);
372
373 // Returns whether `type` can be expressed as an instance of T.
374 // For example,
375 // IsCanonicalRepresentation<float>(F32) // true
376 // IsCanonicalRepresentation<xla::bfloat16>(BF16) // true
377 // IsCanonicalRepresentation<uint32_t>(S8) // true, 8 <= 32
378 // IsCanonicalRepresentation<uint8_t>(S16) // false, 16 > 8
379 template <typename T>
380 bool IsCanonicalRepresentation(PrimitiveType type) {
381 switch (type) {
382 case F16:
383 case F32:
384 case BF16:
385 case F64:
386 case C64:
387 case C128:
388 return NativeToPrimitiveType<T>() == type;
389 case S8:
390 case S16:
391 case S32:
392 case S64:
393 return std::is_integral<T>::value && std::is_signed<T>::value &&
394 ByteWidth(type) <= sizeof(T);
395 case PRED:
396 case U8:
397 case U16:
398 case U32:
399 case U64:
400 return std::is_integral<T>::value && std::is_unsigned<T>::value &&
401 ByteWidth(type) <= sizeof(T);
402 case TUPLE:
403 case OPAQUE_TYPE:
404 case TOKEN:
405 case PRIMITIVE_TYPE_INVALID:
406 case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
407 case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
408 return false;
409 }
410 }
411
412 } // namespace primitive_util
413 } // namespace xla
414
415 #endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
416