1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utilities for dealing with XLA primitive types.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20
21 #include <type_traits>
22
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30 namespace primitive_util {
31
// Returns the count of significand (mantissa) bits for floating-point types.
// NOTE(review): whether this count includes the implicit leading bit is decided
// by the definition in primitive_util.cc — confirm before relying on it.
// For non-floating-point types, results in a LOG(FATAL).
int SignificandWidth(PrimitiveType type);

// Returns the count of exponent bits for floating-point types.
// For non-floating-point types, results in a LOG(FATAL).
int ExponentWidth(PrimitiveType type);

// Returns the exponent of the smallest number that cannot be represented by
// the given floating-point type, i.e. the first exponent that overflows.
// For non-floating-point types, results in a LOG(FATAL).
int OverflowExponent(PrimitiveType type);
43
44 // Returns the XLA primitive type (eg, F32) corresponding to the given
45 // template parameter native type (eg, float).
46 template <typename NativeT>
NativeToPrimitiveType()47 PrimitiveType NativeToPrimitiveType() {
48 // Make the expression depend on the template parameter NativeT so
49 // that this compile-time error only appears if this function is
50 // instantiated with some concrete type that is not specialized
51 // below.
52 static_assert(!std::is_same<NativeT, NativeT>::value,
53 "Cannot map native type to primitive type.");
54 return PRIMITIVE_TYPE_INVALID;
55 }
56
57 // Declarations of specializations for each native type which correspond to a
58 // XLA primitive type. As an optimization, these are declared inline in the
59 // header.
60 template <>
61 inline PrimitiveType NativeToPrimitiveType<bool>() {
62 return PRED;
63 }
64
65 // Unsigned integer
66 template <>
67 inline PrimitiveType NativeToPrimitiveType<uint8>() {
68 return U8;
69 }
70
71 template <>
72 inline PrimitiveType NativeToPrimitiveType<uint16>() {
73 return U16;
74 }
75
76 template <>
77 inline PrimitiveType NativeToPrimitiveType<uint32>() {
78 return U32;
79 }
80
81 template <>
82 inline PrimitiveType NativeToPrimitiveType<uint64>() {
83 return U64;
84 }
85
86 // Signed integer
87 template <>
88 inline PrimitiveType NativeToPrimitiveType<int8>() {
89 return S8;
90 }
91
92 template <>
93 inline PrimitiveType NativeToPrimitiveType<int16>() {
94 return S16;
95 }
96
97 template <>
98 inline PrimitiveType NativeToPrimitiveType<int32>() {
99 return S32;
100 }
101
102 template <>
103 inline PrimitiveType NativeToPrimitiveType<int64>() {
104 return S64;
105 }
106
107 // Floating point
108 template <>
109 inline PrimitiveType NativeToPrimitiveType<float>() {
110 return F32;
111 }
112
113 template <>
114 inline PrimitiveType NativeToPrimitiveType<double>() {
115 return F64;
116 }
117
118 template <>
119 inline PrimitiveType NativeToPrimitiveType<half>() {
120 return F16;
121 }
122
123 template <>
124 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
125 return BF16;
126 }
127
128 // Complex
129 template <>
130 inline PrimitiveType NativeToPrimitiveType<complex64>() {
131 return C64;
132 }
133
134 template <>
135 inline PrimitiveType NativeToPrimitiveType<complex128>() {
136 return C128;
137 }
138
// Returns true if `type` is a floating-point primitive type.
// NOTE(review): the exact membership (e.g. F16/BF16) is defined in
// primitive_util.cc — confirm there.
bool IsFloatingPointType(PrimitiveType type);

// Returns true if `type` is a complex primitive type.
bool IsComplexType(PrimitiveType type);

// Returns true if `type` is a signed integral primitive type.
bool IsSignedIntegralType(PrimitiveType type);

// Returns true if `type` is an unsigned integral primitive type.
bool IsUnsignedIntegralType(PrimitiveType type);

// Returns true if `type` is an integral primitive type.
// NOTE(review): presumably the union of the signed and unsigned predicates
// above; whether PRED counts as integral is defined in primitive_util.cc.
bool IsIntegralType(PrimitiveType type);

// Returns true if values of the given primitive type are held in array shapes.
bool IsArrayType(PrimitiveType primitive_type);

// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);

// Returns the unsigned integral primitive type whose width is `src_bitwidth`
// bits. NOTE(review): the result for widths with no matching type is defined
// in primitive_util.cc — confirm before passing arbitrary widths.
PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);

// Returns the signed integral primitive type whose width is `src_bitwidth`
// bits. NOTE(review): the result for widths with no matching type is defined
// in primitive_util.cc — confirm before passing arbitrary widths.
PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth);

// Returns the primitive type of the real and imaginary components of the given
// complex type.
// LOG(FATAL)s if complex_type is not complex.
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
162
163 // Returns the native type (eg, float) corresponding to the given template
164 // parameter XLA primitive type (eg, F32).
165 template <PrimitiveType>
166 struct PrimitiveTypeToNative;
167
168 // Declarations of specializations for each native type which correspond to a
169 // XLA primitive type.
170 template <>
171 struct PrimitiveTypeToNative<PRED> {
172 using type = bool;
173 };
174
175 // Unsigned integer
176 template <>
177 struct PrimitiveTypeToNative<U8> {
178 using type = uint8;
179 };
180
181 template <>
182 struct PrimitiveTypeToNative<U16> {
183 using type = uint16;
184 };
185
186 template <>
187 struct PrimitiveTypeToNative<U32> {
188 using type = uint32;
189 };
190
191 template <>
192 struct PrimitiveTypeToNative<U64> {
193 using type = uint64;
194 };
195
196 // Signed integer
197 template <>
198 struct PrimitiveTypeToNative<S8> {
199 using type = int8;
200 };
201
202 template <>
203 struct PrimitiveTypeToNative<S16> {
204 using type = int16;
205 };
206
207 template <>
208 struct PrimitiveTypeToNative<S32> {
209 using type = int32;
210 };
211
212 template <>
213 struct PrimitiveTypeToNative<S64> {
214 using type = int64;
215 };
216
217 // Floating point
218 template <>
219 struct PrimitiveTypeToNative<F32> {
220 using type = float;
221 };
222 template <>
223 struct PrimitiveTypeToNative<F64> {
224 using type = double;
225 };
226 template <>
227 struct PrimitiveTypeToNative<F16> {
228 using type = half;
229 };
230
231 template <>
232 struct PrimitiveTypeToNative<BF16> {
233 using type = bfloat16;
234 };
235
236 // Complex
237 template <>
238 struct PrimitiveTypeToNative<C64> {
239 using type = complex64;
240 };
241
242 template <>
243 struct PrimitiveTypeToNative<C128> {
244 using type = complex128;
245 };
246
// Returns the lower-case name of the given primitive type.
// The name is returned by const reference. NOTE(review): presumably it refers
// to storage owned by primitive_util.cc that outlives any caller — confirm
// before holding the reference long-term.
const string& LowercasePrimitiveTypeName(PrimitiveType s);

// Returns the PrimitiveType matching the given name. The given name is expected
// to be lower-case. If the name cannot be matched, the returned StatusOr holds
// a non-OK status instead of a type.
StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);

// Returns true if the given name is a primitive type string (lower-case).
bool IsPrimitiveTypeName(absl::string_view name);
256
257 } // namespace primitive_util
258 } // namespace xla
259
260 #endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
261