• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with XLA primitive types.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20 
#include <tuple>
#include <type_traits>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace primitive_util {
31 
32 // Returns the count of significand (mantissa) bits for float datatypes.
33 // For non-float datatypes, results in a LOG(FATAL).
34 int SignificandWidth(PrimitiveType type);
35 
36 // Returns the count of exponent bits for float datatypes.
37 // For non-float datatypes, results in a LOG(FATAL).
38 int ExponentWidth(PrimitiveType type);
39 
40 // Returns the exponent of the smallest number which cannot be represented.
41 // For non-float datatypes, results in a LOG(FATAL).
42 int OverflowExponent(PrimitiveType type);
43 
44 // Returns the XLA primitive type (eg, F32) corresponding to the given
45 // template parameter native type (eg, float).
46 template <typename NativeT>
NativeToPrimitiveType()47 PrimitiveType NativeToPrimitiveType() {
48   // Make the expression depend on the template parameter NativeT so
49   // that this compile-time error only appears if this function is
50   // instantiated with some concrete type that is not specialized
51   // below.
52   static_assert(!std::is_same<NativeT, NativeT>::value,
53                 "Cannot map native type to primitive type.");
54   return PRIMITIVE_TYPE_INVALID;
55 }
56 
57 // Declarations of specializations for each native type which correspond to a
58 // XLA primitive type.  As an optimization, these are declared inline in the
59 // header.
60 template <>
61 inline PrimitiveType NativeToPrimitiveType<bool>() {
62   return PRED;
63 }
64 
65 // Unsigned integer
66 template <>
67 inline PrimitiveType NativeToPrimitiveType<uint8>() {
68   return U8;
69 }
70 
71 template <>
72 inline PrimitiveType NativeToPrimitiveType<uint16>() {
73   return U16;
74 }
75 
76 template <>
77 inline PrimitiveType NativeToPrimitiveType<uint32>() {
78   return U32;
79 }
80 
81 template <>
82 inline PrimitiveType NativeToPrimitiveType<uint64>() {
83   return U64;
84 }
85 
86 // Signed integer
87 template <>
88 inline PrimitiveType NativeToPrimitiveType<int8>() {
89   return S8;
90 }
91 
92 template <>
93 inline PrimitiveType NativeToPrimitiveType<int16>() {
94   return S16;
95 }
96 
97 template <>
98 inline PrimitiveType NativeToPrimitiveType<int32>() {
99   return S32;
100 }
101 
102 template <>
103 inline PrimitiveType NativeToPrimitiveType<int64>() {
104   return S64;
105 }
106 
107 // Floating point
108 template <>
109 inline PrimitiveType NativeToPrimitiveType<float>() {
110   return F32;
111 }
112 
113 template <>
114 inline PrimitiveType NativeToPrimitiveType<double>() {
115   return F64;
116 }
117 
118 template <>
119 inline PrimitiveType NativeToPrimitiveType<half>() {
120   return F16;
121 }
122 
123 template <>
124 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
125   return BF16;
126 }
127 
128 // Complex
129 template <>
130 inline PrimitiveType NativeToPrimitiveType<complex64>() {
131   return C64;
132 }
133 
134 template <>
135 inline PrimitiveType NativeToPrimitiveType<complex128>() {
136   return C128;
137 }
138 
139 bool IsFloatingPointType(PrimitiveType type);
140 
141 bool IsComplexType(PrimitiveType type);
142 
143 bool IsSignedIntegralType(PrimitiveType type);
144 
145 bool IsUnsignedIntegralType(PrimitiveType type);
146 
147 bool IsIntegralType(PrimitiveType type);
148 
149 // Returns true if values of the given primitive type are held in array shapes.
150 bool IsArrayType(PrimitiveType primitive_type);
151 
152 // Returns the number of bits in the representation for a given type.
153 int BitWidth(PrimitiveType type);
154 
155 // Returns the number of bytes in the representation for a given type.
156 int ByteWidth(PrimitiveType type);
157 
158 PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth);
159 
160 PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth);
161 
162 // Returns the real, imag component type underlying the given complex type.
163 // LOG(FATAL)'s if complex_type is not complex.
164 PrimitiveType ComplexComponentType(PrimitiveType complex_type);
165 
166 // Returns the higher-precision element type if a and b are both floating
167 // point types; otherwise, checks that they have the same element type
168 // and returns it.
HigherPrecisionType(PrimitiveType a,PrimitiveType b)169 inline PrimitiveType HigherPrecisionType(PrimitiveType a, PrimitiveType b) {
170   // Returns a tuple where the elements are lexicographically ordered in terms
171   // of importance.
172   auto type_properties = [](PrimitiveType type) {
173     return std::make_tuple(
174         // Prefer floating point types with more range over other
175         // floating-point types or non-floating point types.
176         IsFloatingPointType(type) ? OverflowExponent(type) : -1,
177         // Prefer floating point types with more precision over less precise
178         // types.
179         IsFloatingPointType(type) ? SignificandWidth(type) : -1,
180         // Prefer wider types over narrower types.
181         BitWidth(type),
182         // Prefer signed integer types over unsigned integer types.
183         IsSignedIntegralType(type));
184   };
185   auto a_properties = type_properties(a);
186   auto b_properties = type_properties(b);
187   if (a_properties > b_properties) {
188     return a;
189   }
190   if (b_properties > a_properties) {
191     return b;
192   }
193   CHECK_EQ(a, b);
194   return a;
195 }
196 
197 // Returns true if a convert from from_type to to_type looses no precision.
CastPreservesValues(PrimitiveType from_type,PrimitiveType to_type)198 inline bool CastPreservesValues(PrimitiveType from_type,
199                                 PrimitiveType to_type) {
200   if (from_type == to_type) {
201     return true;
202   }
203   switch (to_type) {
204     case C128:
205       if (from_type == F64) {
206         return true;
207       }
208       ABSL_FALLTHROUGH_INTENDED;
209     case F64:
210       if (from_type == S32 || from_type == U32 || from_type == F32) {
211         return true;
212       }
213       ABSL_FALLTHROUGH_INTENDED;
214     case C64:
215       if (from_type == F32) {
216         return true;
217       }
218       ABSL_FALLTHROUGH_INTENDED;
219     case F32:
220       if (from_type == F16 || from_type == BF16 || from_type == S16 ||
221           from_type == U16) {
222         return true;
223       }
224       ABSL_FALLTHROUGH_INTENDED;
225     case F16:
226     case BF16:
227       return from_type == U8 || from_type == S8 || from_type == PRED;
228     case S64:
229       if (from_type == S32 || from_type == U32) {
230         return true;
231       }
232       ABSL_FALLTHROUGH_INTENDED;
233     case S32:
234       if (from_type == S16 || from_type == U16) {
235         return true;
236       }
237       ABSL_FALLTHROUGH_INTENDED;
238     case S16:
239       if (from_type == S8 || from_type == U8) {
240         return true;
241       }
242       ABSL_FALLTHROUGH_INTENDED;
243     case S8:
244       if (from_type == PRED) {
245         return true;
246       }
247       ABSL_FALLTHROUGH_INTENDED;
248     case PRED:
249       return false;
250     case U64:
251       if (from_type == U32) {
252         return true;
253       }
254       ABSL_FALLTHROUGH_INTENDED;
255     case U32:
256       if (from_type == U16) {
257         return true;
258       }
259       ABSL_FALLTHROUGH_INTENDED;
260     case U16:
261       if (from_type == U8) {
262         return true;
263       }
264       ABSL_FALLTHROUGH_INTENDED;
265     case U8:
266       return from_type == PRED;
267     default:
268       return false;
269   }
270 }
271 
272 // Returns the native type (eg, float) corresponding to the given template
273 // parameter XLA primitive type (eg, F32).
274 template <PrimitiveType>
275 struct PrimitiveTypeToNative;
276 
277 // Declarations of specializations for each native type which correspond to a
278 // XLA primitive type.
279 template <>
280 struct PrimitiveTypeToNative<PRED> {
281   using type = bool;
282 };
283 
284 // Unsigned integer
285 template <>
286 struct PrimitiveTypeToNative<U8> {
287   using type = uint8;
288 };
289 
290 template <>
291 struct PrimitiveTypeToNative<U16> {
292   using type = uint16;
293 };
294 
295 template <>
296 struct PrimitiveTypeToNative<U32> {
297   using type = uint32;
298 };
299 
300 template <>
301 struct PrimitiveTypeToNative<U64> {
302   using type = uint64;
303 };
304 
305 // Signed integer
306 template <>
307 struct PrimitiveTypeToNative<S8> {
308   using type = int8;
309 };
310 
311 template <>
312 struct PrimitiveTypeToNative<S16> {
313   using type = int16;
314 };
315 
316 template <>
317 struct PrimitiveTypeToNative<S32> {
318   using type = int32;
319 };
320 
321 template <>
322 struct PrimitiveTypeToNative<S64> {
323   using type = int64;
324 };
325 
326 // Floating point
327 template <>
328 struct PrimitiveTypeToNative<F32> {
329   using type = float;
330 };
331 template <>
332 struct PrimitiveTypeToNative<F64> {
333   using type = double;
334 };
335 template <>
336 struct PrimitiveTypeToNative<F16> {
337   using type = half;
338 };
339 
340 template <>
341 struct PrimitiveTypeToNative<BF16> {
342   using type = bfloat16;
343 };
344 
345 // Complex
346 template <>
347 struct PrimitiveTypeToNative<C64> {
348   using type = complex64;
349 };
350 
351 template <>
352 struct PrimitiveTypeToNative<C128> {
353   using type = complex128;
354 };
355 
356 // Returns the lower-case name of the given primitive type.
357 const string& LowercasePrimitiveTypeName(PrimitiveType s);
358 
359 // Returns the PrimitiveType matching the given name. The given name is expected
360 // to be lower-case.
361 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);
362 
363 // Returns true if the given name is a primitive type string (lower-case).
364 bool IsPrimitiveTypeName(absl::string_view name);
365 
366 }  // namespace primitive_util
367 }  // namespace xla
368 
369 #endif  // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
370