1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with XLA primitive types.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20
21 #include <type_traits>
22
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30 namespace primitive_util {
31
32 // Returns the count of significand (mantissa) bits for float datatypes.
33 // For non-float datatypes, results in a LOG(FATAL).
34 int SignificandWidth(PrimitiveType type);
35
// Width, in bits, of the exponent field of a BF16 value. Declared constexpr so
// that it is guaranteed to be a compile-time constant.
constexpr int kBFloat16ExponentBits = 8;
38
// Width, in bits, of the explicitly stored mantissa of a BF16 value. The
// leading 1 is implicit, which gives one additional bit of precision beyond
// the stored bits. Declared constexpr so that it is guaranteed to be a
// compile-time constant.
constexpr int kBFloat16MantissaBits = 7;
42
43 // Returns the XLA primitive type (eg, F32) corresponding to the given
44 // template parameter native type (eg, float).
45 template <typename NativeT>
NativeToPrimitiveType()46 PrimitiveType NativeToPrimitiveType() {
47 // Make the expression depend on the template parameter NativeT so
48 // that this compile-time error only apperas if this function is
49 // instantiated with some concrete type that is not specialized
50 // below.
51 static_assert(!std::is_same<NativeT, NativeT>::value,
52 "Cannot map native type to primitive type.");
53 return PRIMITIVE_TYPE_INVALID;
54 }
55
56 // Declarations of specializations for each native type which correspond to a
57 // XLA primitive type. As an optimization, these are declared inline in the
58 // header.
59 template <>
60 inline PrimitiveType NativeToPrimitiveType<bool>() {
61 return PRED;
62 }
63
64 // Unsigned integer
65 template <>
66 inline PrimitiveType NativeToPrimitiveType<uint8>() {
67 return U8;
68 }
69
70 template <>
71 inline PrimitiveType NativeToPrimitiveType<uint16>() {
72 return U16;
73 }
74
75 template <>
76 inline PrimitiveType NativeToPrimitiveType<uint32>() {
77 return U32;
78 }
79
80 template <>
81 inline PrimitiveType NativeToPrimitiveType<uint64>() {
82 return U64;
83 }
84
85 // Signed integer
86 template <>
87 inline PrimitiveType NativeToPrimitiveType<int8>() {
88 return S8;
89 }
90
91 template <>
92 inline PrimitiveType NativeToPrimitiveType<int16>() {
93 return S16;
94 }
95
96 template <>
97 inline PrimitiveType NativeToPrimitiveType<int32>() {
98 return S32;
99 }
100
101 template <>
102 inline PrimitiveType NativeToPrimitiveType<int64>() {
103 return S64;
104 }
105
106 // Floating point
107 template <>
108 inline PrimitiveType NativeToPrimitiveType<float>() {
109 return F32;
110 }
111
112 template <>
113 inline PrimitiveType NativeToPrimitiveType<double>() {
114 return F64;
115 }
116
117 template <>
118 inline PrimitiveType NativeToPrimitiveType<half>() {
119 return F16;
120 }
121
122 template <>
123 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
124 return BF16;
125 }
126
127 // Complex
128 template <>
129 inline PrimitiveType NativeToPrimitiveType<complex64>() {
130 return C64;
131 }
132
133 template <>
134 inline PrimitiveType NativeToPrimitiveType<complex128>() {
135 return C128;
136 }
137
138 bool IsFloatingPointType(PrimitiveType type);
139
140 bool IsComplexType(PrimitiveType type);
141
142 bool IsSignedIntegralType(PrimitiveType type);
143
144 bool IsUnsignedIntegralType(PrimitiveType type);
145
146 bool IsIntegralType(PrimitiveType type);
147
148 // Returns true if values of the given primitive type are held in array shapes.
149 bool IsArrayType(PrimitiveType primitive_type);
150
151 // Returns the number of bits in the representation for a given type.
152 int BitWidth(PrimitiveType type);
153
154 PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
155
156 // Returns the real, imag component type underlying the given complex type.
157 // LOG(FATAL)'s if complex_type is not complex.
158 PrimitiveType ComplexComponentType(PrimitiveType complex_type);
159
160 // Returns the native type (eg, float) corresponding to the given template
161 // parameter XLA primitive type (eg, F32).
162 template <PrimitiveType>
163 struct PrimitiveTypeToNative;
164
165 // Declarations of specializations for each native type which correspond to a
166 // XLA primitive type.
167 template <>
168 struct PrimitiveTypeToNative<PRED> {
169 using type = bool;
170 };
171
172 // Unsigned integer
173 template <>
174 struct PrimitiveTypeToNative<U8> {
175 using type = uint8;
176 };
177
178 template <>
179 struct PrimitiveTypeToNative<U16> {
180 using type = uint16;
181 };
182
183 template <>
184 struct PrimitiveTypeToNative<U32> {
185 using type = uint32;
186 };
187
188 template <>
189 struct PrimitiveTypeToNative<U64> {
190 using type = uint64;
191 };
192
193 // Signed integer
194 template <>
195 struct PrimitiveTypeToNative<S8> {
196 using type = int8;
197 };
198
199 template <>
200 struct PrimitiveTypeToNative<S16> {
201 using type = int16;
202 };
203
204 template <>
205 struct PrimitiveTypeToNative<S32> {
206 using type = int32;
207 };
208
209 template <>
210 struct PrimitiveTypeToNative<S64> {
211 using type = int64;
212 };
213
214 // Floating point
215 template <>
216 struct PrimitiveTypeToNative<F32> {
217 using type = float;
218 };
219 template <>
220 struct PrimitiveTypeToNative<F64> {
221 using type = double;
222 };
223 template <>
224 struct PrimitiveTypeToNative<F16> {
225 using type = half;
226 };
227
228 template <>
229 struct PrimitiveTypeToNative<BF16> {
230 using type = bfloat16;
231 };
232
233 // Complex
234 template <>
235 struct PrimitiveTypeToNative<C64> {
236 using type = complex64;
237 };
238
239 template <>
240 struct PrimitiveTypeToNative<C128> {
241 using type = complex128;
242 };
243
244 // Returns the lower-case name of the given primitive type.
245 const string& LowercasePrimitiveTypeName(PrimitiveType s);
246
247 // Returns the PrimitiveType matching the given name. The given name is expected
248 // to be lower-case.
249 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);
250
251 // Returns true if the given name is a primitive type string (lower-case).
252 bool IsPrimitiveTypeName(absl::string_view name);
253
254 } // namespace primitive_util
255 } // namespace xla
256
257 #endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
258