• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with XLA primitive types.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20 
#include <string>
#include <tuple>
#include <type_traits>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 namespace primitive_util {
32 
// Returns the count of significand (mantissa) bits for float datatypes.
// NOTE(review): confirm whether the implicit leading bit is counted; the
// integer -> float branch of CastPreservesValues compares this value directly
// against the number of value bits of the source integer type.
// For non-float datatypes, results in a LOG(FATAL).
int SignificandWidth(PrimitiveType type);

// Returns the count of exponent bits for float datatypes.
// For non-float datatypes, results in a LOG(FATAL).
int ExponentWidth(PrimitiveType type);

// Returns the exponent of the smallest number which cannot be represented
// (presumably the smallest e such that 2^e overflows the type -- TODO confirm
// against the out-of-line definition).
// For non-float datatypes, results in a LOG(FATAL).
int OverflowExponent(PrimitiveType type);
44 
45 // Returns the XLA primitive type (eg, F32) corresponding to the given
46 // template parameter native type (eg, float).
47 template <typename NativeT>
NativeToPrimitiveType()48 PrimitiveType NativeToPrimitiveType() {
49   // Make the expression depend on the template parameter NativeT so
50   // that this compile-time error only appears if this function is
51   // instantiated with some concrete type that is not specialized
52   // below.
53   static_assert(!std::is_same<NativeT, NativeT>::value,
54                 "Cannot map native type to primitive type.");
55   return PRIMITIVE_TYPE_INVALID;
56 }
57 
58 // Declarations of specializations for each native type which correspond to a
59 // XLA primitive type.  As an optimization, these are declared inline in the
60 // header.
61 template <>
62 inline PrimitiveType NativeToPrimitiveType<bool>() {
63   return PRED;
64 }
65 
66 // Unsigned integer
67 template <>
68 inline PrimitiveType NativeToPrimitiveType<uint8_t>() {
69   return U8;
70 }
71 
72 template <>
73 inline PrimitiveType NativeToPrimitiveType<uint16_t>() {
74   return U16;
75 }
76 
77 template <>
78 inline PrimitiveType NativeToPrimitiveType<uint32_t>() {
79   return U32;
80 }
81 
82 template <>
83 inline PrimitiveType NativeToPrimitiveType<uint64_t>() {
84   return U64;
85 }
86 
87 // Signed integer
88 template <>
89 inline PrimitiveType NativeToPrimitiveType<int8_t>() {
90   return S8;
91 }
92 
93 template <>
94 inline PrimitiveType NativeToPrimitiveType<int16_t>() {
95   return S16;
96 }
97 
98 template <>
99 inline PrimitiveType NativeToPrimitiveType<int32_t>() {
100   return S32;
101 }
102 
103 template <>
104 inline PrimitiveType NativeToPrimitiveType<int64_t>() {
105   return S64;
106 }
107 
108 // Floating point
109 template <>
110 inline PrimitiveType NativeToPrimitiveType<float>() {
111   return F32;
112 }
113 
114 template <>
115 inline PrimitiveType NativeToPrimitiveType<double>() {
116   return F64;
117 }
118 
119 template <>
120 inline PrimitiveType NativeToPrimitiveType<half>() {
121   return F16;
122 }
123 
124 template <>
125 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
126   return BF16;
127 }
128 
129 // Complex
130 template <>
131 inline PrimitiveType NativeToPrimitiveType<complex64>() {
132   return C64;
133 }
134 
135 template <>
136 inline PrimitiveType NativeToPrimitiveType<complex128>() {
137   return C128;
138 }
139 
// Type-classification predicates. They are only declared in this header; the
// exact membership of each category (for example whether PRED counts as
// integral) is defined out of line and should be confirmed there.

// Returns whether `type` is a (real) floating-point type
// (presumably F16, BF16, F32, F64 -- TODO confirm).
bool IsFloatingPointType(PrimitiveType type);

// Returns whether `type` is a complex type (presumably C64 and C128).
bool IsComplexType(PrimitiveType type);

// Returns whether `type` is a signed integral type (presumably S8..S64).
bool IsSignedIntegralType(PrimitiveType type);

// Returns whether `type` is an unsigned integral type (presumably U8..U64).
bool IsUnsignedIntegralType(PrimitiveType type);

// Returns whether `type` is integral. NOTE(review): presumably the union of
// IsSignedIntegralType and IsUnsignedIntegralType -- confirm.
bool IsIntegralType(PrimitiveType type);
149 
150 // Returns true if values of the given primitive type are held in array shapes.
IsArrayType(PrimitiveType primitive_type)151 inline constexpr bool IsArrayType(PrimitiveType primitive_type) {
152   return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
153          primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
154 }
155 
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);

// Returns the number of bytes in the representation for a given type.
int ByteWidth(PrimitiveType type);

// Returns the unsigned integral type whose width is `src_bitwidth` bits
// (presumably one of U8, U16, U32, U64). NOTE(review): the result for
// unsupported widths is not visible from this header -- confirm against the
// definition before relying on it.
PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth);

// Signed counterpart of UnsignedIntegralTypeForBitWidth (presumably one of
// S8, S16, S32, S64). NOTE(review): same caveat for unsupported widths.
PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth);

// Returns the real, imag component type underlying the given complex type.
// LOG(FATAL)'s if complex_type is not complex.
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
169 
170 // Returns the higher-precision element type if a and b are both floating
171 // point types; otherwise, checks that they have the same element type
172 // and returns it.
HigherPrecisionType(PrimitiveType a,PrimitiveType b)173 inline PrimitiveType HigherPrecisionType(PrimitiveType a, PrimitiveType b) {
174   // Returns a tuple where the elements are lexicographically ordered in terms
175   // of importance.
176   auto type_properties = [](PrimitiveType type) {
177     auto component_type =
178         IsComplexType(type) ? ComplexComponentType(type) : type;
179     return std::make_tuple(
180         // Prefer complex types over non-complex types.
181         IsComplexType(type),
182         // Prefer floating point types with more range over other
183         // floating-point types or non-floating point types.
184         IsFloatingPointType(component_type) ? OverflowExponent(component_type)
185                                             : -1,
186         // Prefer floating point types with more precision over less precise
187         // types.
188         IsFloatingPointType(component_type) ? SignificandWidth(component_type)
189                                             : -1,
190         // Prefer wider types over narrower types.
191         BitWidth(component_type),
192         // Prefer signed integer types over unsigned integer types.
193         IsSignedIntegralType(component_type));
194   };
195   auto a_properties = type_properties(a);
196   auto b_properties = type_properties(b);
197   if (a_properties > b_properties) {
198     return a;
199   }
200   if (b_properties > a_properties) {
201     return b;
202   }
203   CHECK_EQ(a, b);
204   return a;
205 }
206 
207 // Returns true if a convert from from_type to to_type loses no precision.
CastPreservesValues(PrimitiveType from_type,PrimitiveType to_type)208 inline bool CastPreservesValues(PrimitiveType from_type,
209                                 PrimitiveType to_type) {
210   // * -> *
211   if (from_type == to_type) {
212     return true;
213   }
214   // PRED -> *
215   if (from_type == PRED) {
216     return true;
217   }
218   // ~PRED -> PRED is not safe because it drops almost all numbers.
219   if (to_type == PRED) {
220     return false;
221   }
222   // * -> C is safe if the components of * and C can be safely converted.
223   if (primitive_util::IsComplexType(to_type)) {
224     auto from_component_type =
225         primitive_util::IsComplexType(from_type)
226             ? primitive_util::ComplexComponentType(from_type)
227             : from_type;
228     auto to_component_type = primitive_util::ComplexComponentType(to_type);
229     return CastPreservesValues(from_component_type, to_component_type);
230   }
231   // ~C -> C is not safe because it drops imaginary components.
232   if (primitive_util::IsComplexType(from_type)) {
233     return false;
234   }
235   // F -> F is safe if the exponent and significand are preserved.
236   if (primitive_util::IsFloatingPointType(from_type) &&
237       primitive_util::IsFloatingPointType(to_type)) {
238     return primitive_util::SignificandWidth(from_type) <=
239                primitive_util::SignificandWidth(to_type) &&
240            primitive_util::ExponentWidth(from_type) <=
241                primitive_util::ExponentWidth(to_type) &&
242            primitive_util::OverflowExponent(from_type) <=
243                primitive_util::OverflowExponent(to_type);
244   }
245   // F -> I is not safe because it drops fractional numbers.
246   if (!primitive_util::IsIntegralType(from_type)) {
247     return false;
248   }
249   // An n-bit unsigned integer takes on values from [0, 2^n - 1].
250   // An n-bit signed integer takes on values from [-2^(n-1), 2^(n-1) - 1].
251   // from_bits/to_bits considers the number of non-sign bits.
252   const int from_bits = primitive_util::IsSignedIntegralType(from_type)
253                             ? primitive_util::BitWidth(from_type) - 1
254                             : primitive_util::BitWidth(from_type);
255   const int to_bits = primitive_util::IsSignedIntegralType(to_type)
256                           ? primitive_util::BitWidth(to_type) - 1
257                           : primitive_util::BitWidth(to_type);
258   // I -> F is safe if the integer can be represented exactly.
259   if (primitive_util::IsFloatingPointType(to_type)) {
260     // In both cases, we need to handle an exponent of n-1.
261     // However, the significand needed to represent signed two's complement
262     // numbers is smaller by one bit because it will only have a non-zero
263     // trailing significand field when the exponent is smaller than n-1.
264     return from_bits <= primitive_util::SignificandWidth(to_type) &&
265            primitive_util::BitWidth(from_type) - 1 <
266                primitive_util::OverflowExponent(to_type);
267   }
268   // S -> U is not safe because it drops negative numbers.
269   if (primitive_util::IsSignedIntegralType(from_type) &&
270       primitive_util::IsUnsignedIntegralType(to_type)) {
271     return false;
272   }
273   // I -> I is safe if the integer can be represented exactly; we've already
274   // ensured that signed to unsigned conversions won't happen here.
275   CHECK(primitive_util::IsIntegralType(to_type));
276   return from_bits <= to_bits;
277 }
278 
279 // Returns the native type (eg, float) corresponding to the given template
280 // parameter XLA primitive type (eg, F32).
281 template <PrimitiveType>
282 struct PrimitiveTypeToNative;
283 
284 // Declarations of specializations for each native type which correspond to a
285 // XLA primitive type.
286 template <>
287 struct PrimitiveTypeToNative<PRED> {
288   using type = bool;
289 };
290 
291 // Unsigned integer
292 template <>
293 struct PrimitiveTypeToNative<U8> {
294   using type = uint8_t;
295 };
296 
297 template <>
298 struct PrimitiveTypeToNative<U16> {
299   using type = uint16_t;
300 };
301 
302 template <>
303 struct PrimitiveTypeToNative<U32> {
304   using type = uint32_t;
305 };
306 
307 template <>
308 struct PrimitiveTypeToNative<U64> {
309   using type = uint64_t;
310 };
311 
312 // Signed integer
313 template <>
314 struct PrimitiveTypeToNative<S8> {
315   using type = int8_t;
316 };
317 
318 template <>
319 struct PrimitiveTypeToNative<S16> {
320   using type = int16_t;
321 };
322 
323 template <>
324 struct PrimitiveTypeToNative<S32> {
325   using type = int32_t;
326 };
327 
328 template <>
329 struct PrimitiveTypeToNative<S64> {
330   using type = int64_t;
331 };
332 
333 // Floating point
334 template <>
335 struct PrimitiveTypeToNative<F32> {
336   using type = float;
337 };
338 template <>
339 struct PrimitiveTypeToNative<F64> {
340   using type = double;
341 };
342 template <>
343 struct PrimitiveTypeToNative<F16> {
344   using type = half;
345 };
346 
347 template <>
348 struct PrimitiveTypeToNative<BF16> {
349   using type = bfloat16;
350 };
351 
352 // Complex
353 template <>
354 struct PrimitiveTypeToNative<C64> {
355   using type = complex64;
356 };
357 
358 template <>
359 struct PrimitiveTypeToNative<C128> {
360   using type = complex128;
361 };
362 
// Returns the lower-case name of the given primitive type.
// NOTE(review): the result is returned by reference, so the referenced string
// must outlive the call; its owner and lifetime are defined out of line --
// confirm before storing the reference.
const std::string& LowercasePrimitiveTypeName(PrimitiveType s);

// Returns the PrimitiveType matching the given name. The given name is expected
// to be lower-case. Errors are reported through the returned StatusOr
// (presumably for unrecognized names -- TODO confirm against the definition).
StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);

// Returns true if the given name is a primitive type string (lower-case).
bool IsPrimitiveTypeName(absl::string_view name);
372 
373 // Returns whether `type` can be expressed as an instance of T.
374 // For example,
375 //  IsCanonicalRepresentation<float>(F32)          // true
376 //  IsCanonicalRepresentation<xla::bfloat16>(BF16) // true
377 //  IsCanonicalRepresentation<uint32_t>(S8)        // true, 8 <= 32
378 //  IsCanonicalRepresentation<uint8_t>(S16)        // false, 16 > 8
379 template <typename T>
380 bool IsCanonicalRepresentation(PrimitiveType type) {
381   switch (type) {
382     case F16:
383     case F32:
384     case BF16:
385     case F64:
386     case C64:
387     case C128:
388       return NativeToPrimitiveType<T>() == type;
389     case S8:
390     case S16:
391     case S32:
392     case S64:
393       return std::is_integral<T>::value && std::is_signed<T>::value &&
394              ByteWidth(type) <= sizeof(T);
395     case PRED:
396     case U8:
397     case U16:
398     case U32:
399     case U64:
400       return std::is_integral<T>::value && std::is_unsigned<T>::value &&
401              ByteWidth(type) <= sizeof(T);
402     case TUPLE:
403     case OPAQUE_TYPE:
404     case TOKEN:
405     case PRIMITIVE_TYPE_INVALID:
406     case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_:
407     case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_:
408       return false;
409   }
410 }
411 
412 }  // namespace primitive_util
413 }  // namespace xla
414 
415 #endif  // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
416