1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with XLA primitive types.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20
#include <tuple>
#include <type_traits>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30 namespace primitive_util {
31
32 // Returns the count of significand (mantissa) bits for float datatypes.
33 // For non-float datatypes, results in a LOG(FATAL).
34 int SignificandWidth(PrimitiveType type);
35
36 // Returns the count of exponent bits for float datatypes.
37 // For non-float datatypes, results in a LOG(FATAL).
38 int ExponentWidth(PrimitiveType type);
39
40 // Returns the exponent of the smallest number which cannot be represented.
41 // For non-float datatypes, results in a LOG(FATAL).
42 int OverflowExponent(PrimitiveType type);
43
44 // Returns the XLA primitive type (eg, F32) corresponding to the given
45 // template parameter native type (eg, float).
46 template <typename NativeT>
NativeToPrimitiveType()47 PrimitiveType NativeToPrimitiveType() {
48 // Make the expression depend on the template parameter NativeT so
49 // that this compile-time error only appears if this function is
50 // instantiated with some concrete type that is not specialized
51 // below.
52 static_assert(!std::is_same<NativeT, NativeT>::value,
53 "Cannot map native type to primitive type.");
54 return PRIMITIVE_TYPE_INVALID;
55 }
56
57 // Declarations of specializations for each native type which correspond to a
58 // XLA primitive type. As an optimization, these are declared inline in the
59 // header.
60 template <>
61 inline PrimitiveType NativeToPrimitiveType<bool>() {
62 return PRED;
63 }
64
65 // Unsigned integer
66 template <>
67 inline PrimitiveType NativeToPrimitiveType<uint8>() {
68 return U8;
69 }
70
71 template <>
72 inline PrimitiveType NativeToPrimitiveType<uint16>() {
73 return U16;
74 }
75
76 template <>
77 inline PrimitiveType NativeToPrimitiveType<uint32>() {
78 return U32;
79 }
80
81 template <>
82 inline PrimitiveType NativeToPrimitiveType<uint64>() {
83 return U64;
84 }
85
86 // Signed integer
87 template <>
88 inline PrimitiveType NativeToPrimitiveType<int8>() {
89 return S8;
90 }
91
92 template <>
93 inline PrimitiveType NativeToPrimitiveType<int16>() {
94 return S16;
95 }
96
97 template <>
98 inline PrimitiveType NativeToPrimitiveType<int32>() {
99 return S32;
100 }
101
102 template <>
103 inline PrimitiveType NativeToPrimitiveType<int64>() {
104 return S64;
105 }
106
107 // Floating point
108 template <>
109 inline PrimitiveType NativeToPrimitiveType<float>() {
110 return F32;
111 }
112
113 template <>
114 inline PrimitiveType NativeToPrimitiveType<double>() {
115 return F64;
116 }
117
118 template <>
119 inline PrimitiveType NativeToPrimitiveType<half>() {
120 return F16;
121 }
122
123 template <>
124 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
125 return BF16;
126 }
127
128 // Complex
129 template <>
130 inline PrimitiveType NativeToPrimitiveType<complex64>() {
131 return C64;
132 }
133
134 template <>
135 inline PrimitiveType NativeToPrimitiveType<complex128>() {
136 return C128;
137 }
138
139 bool IsFloatingPointType(PrimitiveType type);
140
141 bool IsComplexType(PrimitiveType type);
142
143 bool IsSignedIntegralType(PrimitiveType type);
144
145 bool IsUnsignedIntegralType(PrimitiveType type);
146
147 bool IsIntegralType(PrimitiveType type);
148
149 // Returns true if values of the given primitive type are held in array shapes.
150 bool IsArrayType(PrimitiveType primitive_type);
151
152 // Returns the number of bits in the representation for a given type.
153 int BitWidth(PrimitiveType type);
154
155 // Returns the number of bytes in the representation for a given type.
156 int ByteWidth(PrimitiveType type);
157
158 PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth);
159
160 PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth);
161
162 // Returns the real, imag component type underlying the given complex type.
163 // LOG(FATAL)'s if complex_type is not complex.
164 PrimitiveType ComplexComponentType(PrimitiveType complex_type);
165
166 // Returns the higher-precision element type if a and b are both floating
167 // point types; otherwise, checks that they have the same element type
168 // and returns it.
HigherPrecisionType(PrimitiveType a,PrimitiveType b)169 inline PrimitiveType HigherPrecisionType(PrimitiveType a, PrimitiveType b) {
170 // Returns a tuple where the elements are lexicographically ordered in terms
171 // of importance.
172 auto type_properties = [](PrimitiveType type) {
173 return std::make_tuple(
174 // Prefer floating point types with more range over other
175 // floating-point types or non-floating point types.
176 IsFloatingPointType(type) ? OverflowExponent(type) : -1,
177 // Prefer floating point types with more precision over less precise
178 // types.
179 IsFloatingPointType(type) ? SignificandWidth(type) : -1,
180 // Prefer wider types over narrower types.
181 BitWidth(type),
182 // Prefer signed integer types over unsigned integer types.
183 IsSignedIntegralType(type));
184 };
185 auto a_properties = type_properties(a);
186 auto b_properties = type_properties(b);
187 if (a_properties > b_properties) {
188 return a;
189 }
190 if (b_properties > a_properties) {
191 return b;
192 }
193 CHECK_EQ(a, b);
194 return a;
195 }
196
197 // Returns true if a convert from from_type to to_type looses no precision.
CastPreservesValues(PrimitiveType from_type,PrimitiveType to_type)198 inline bool CastPreservesValues(PrimitiveType from_type,
199 PrimitiveType to_type) {
200 if (from_type == to_type) {
201 return true;
202 }
203 switch (to_type) {
204 case C128:
205 if (from_type == F64) {
206 return true;
207 }
208 ABSL_FALLTHROUGH_INTENDED;
209 case F64:
210 if (from_type == S32 || from_type == U32 || from_type == F32) {
211 return true;
212 }
213 ABSL_FALLTHROUGH_INTENDED;
214 case C64:
215 if (from_type == F32) {
216 return true;
217 }
218 ABSL_FALLTHROUGH_INTENDED;
219 case F32:
220 if (from_type == F16 || from_type == BF16 || from_type == S16 ||
221 from_type == U16) {
222 return true;
223 }
224 ABSL_FALLTHROUGH_INTENDED;
225 case F16:
226 case BF16:
227 return from_type == U8 || from_type == S8 || from_type == PRED;
228 case S64:
229 if (from_type == S32 || from_type == U32) {
230 return true;
231 }
232 ABSL_FALLTHROUGH_INTENDED;
233 case S32:
234 if (from_type == S16 || from_type == U16) {
235 return true;
236 }
237 ABSL_FALLTHROUGH_INTENDED;
238 case S16:
239 if (from_type == S8 || from_type == U8) {
240 return true;
241 }
242 ABSL_FALLTHROUGH_INTENDED;
243 case S8:
244 if (from_type == PRED) {
245 return true;
246 }
247 ABSL_FALLTHROUGH_INTENDED;
248 case PRED:
249 return false;
250 case U64:
251 if (from_type == U32) {
252 return true;
253 }
254 ABSL_FALLTHROUGH_INTENDED;
255 case U32:
256 if (from_type == U16) {
257 return true;
258 }
259 ABSL_FALLTHROUGH_INTENDED;
260 case U16:
261 if (from_type == U8) {
262 return true;
263 }
264 ABSL_FALLTHROUGH_INTENDED;
265 case U8:
266 return from_type == PRED;
267 default:
268 return false;
269 }
270 }
271
272 // Returns the native type (eg, float) corresponding to the given template
273 // parameter XLA primitive type (eg, F32).
274 template <PrimitiveType>
275 struct PrimitiveTypeToNative;
276
277 // Declarations of specializations for each native type which correspond to a
278 // XLA primitive type.
279 template <>
280 struct PrimitiveTypeToNative<PRED> {
281 using type = bool;
282 };
283
284 // Unsigned integer
285 template <>
286 struct PrimitiveTypeToNative<U8> {
287 using type = uint8;
288 };
289
290 template <>
291 struct PrimitiveTypeToNative<U16> {
292 using type = uint16;
293 };
294
295 template <>
296 struct PrimitiveTypeToNative<U32> {
297 using type = uint32;
298 };
299
300 template <>
301 struct PrimitiveTypeToNative<U64> {
302 using type = uint64;
303 };
304
305 // Signed integer
306 template <>
307 struct PrimitiveTypeToNative<S8> {
308 using type = int8;
309 };
310
311 template <>
312 struct PrimitiveTypeToNative<S16> {
313 using type = int16;
314 };
315
316 template <>
317 struct PrimitiveTypeToNative<S32> {
318 using type = int32;
319 };
320
321 template <>
322 struct PrimitiveTypeToNative<S64> {
323 using type = int64;
324 };
325
326 // Floating point
327 template <>
328 struct PrimitiveTypeToNative<F32> {
329 using type = float;
330 };
331 template <>
332 struct PrimitiveTypeToNative<F64> {
333 using type = double;
334 };
335 template <>
336 struct PrimitiveTypeToNative<F16> {
337 using type = half;
338 };
339
340 template <>
341 struct PrimitiveTypeToNative<BF16> {
342 using type = bfloat16;
343 };
344
345 // Complex
346 template <>
347 struct PrimitiveTypeToNative<C64> {
348 using type = complex64;
349 };
350
351 template <>
352 struct PrimitiveTypeToNative<C128> {
353 using type = complex128;
354 };
355
356 // Returns the lower-case name of the given primitive type.
357 const string& LowercasePrimitiveTypeName(PrimitiveType s);
358
359 // Returns the PrimitiveType matching the given name. The given name is expected
360 // to be lower-case.
361 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);
362
363 // Returns true if the given name is a primitive type string (lower-case).
364 bool IsPrimitiveTypeName(absl::string_view name);
365
366 } // namespace primitive_util
367 } // namespace xla
368
369 #endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
370