• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with XLA primitive types.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20 
21 #include <type_traits>
22 
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace primitive_util {
31 
32 // Returns the count of significand (mantissa) bits for float datatypes.
33 // For non-float datatypes, results in a LOG(FATAL).
34 int SignificandWidth(PrimitiveType type);
35 
// Width, in bits, of the exponent field of a BF16 (bfloat16) value.
constexpr int kBFloat16ExponentBits = 8;

// Width, in bits, of the explicitly stored mantissa of a BF16 value. A leading
// 1 bit is implied rather than stored, so the effective precision is one bit
// more than this count.
constexpr int kBFloat16MantissaBits = 7;
42 
43 // Returns the XLA primitive type (eg, F32) corresponding to the given
44 // template parameter native type (eg, float).
45 template <typename NativeT>
NativeToPrimitiveType()46 PrimitiveType NativeToPrimitiveType() {
47   // Make the expression depend on the template parameter NativeT so
48   // that this compile-time error only apperas if this function is
49   // instantiated with some concrete type that is not specialized
50   // below.
51   static_assert(!std::is_same<NativeT, NativeT>::value,
52                 "Cannot map native type to primitive type.");
53   return PRIMITIVE_TYPE_INVALID;
54 }
55 
56 // Declarations of specializations for each native type which correspond to a
57 // XLA primitive type.  As an optimization, these are declared inline in the
58 // header.
59 template <>
60 inline PrimitiveType NativeToPrimitiveType<bool>() {
61   return PRED;
62 }
63 
64 // Unsigned integer
65 template <>
66 inline PrimitiveType NativeToPrimitiveType<uint8>() {
67   return U8;
68 }
69 
70 template <>
71 inline PrimitiveType NativeToPrimitiveType<uint16>() {
72   return U16;
73 }
74 
75 template <>
76 inline PrimitiveType NativeToPrimitiveType<uint32>() {
77   return U32;
78 }
79 
80 template <>
81 inline PrimitiveType NativeToPrimitiveType<uint64>() {
82   return U64;
83 }
84 
85 // Signed integer
86 template <>
87 inline PrimitiveType NativeToPrimitiveType<int8>() {
88   return S8;
89 }
90 
91 template <>
92 inline PrimitiveType NativeToPrimitiveType<int16>() {
93   return S16;
94 }
95 
96 template <>
97 inline PrimitiveType NativeToPrimitiveType<int32>() {
98   return S32;
99 }
100 
101 template <>
102 inline PrimitiveType NativeToPrimitiveType<int64>() {
103   return S64;
104 }
105 
106 // Floating point
107 template <>
108 inline PrimitiveType NativeToPrimitiveType<float>() {
109   return F32;
110 }
111 
112 template <>
113 inline PrimitiveType NativeToPrimitiveType<double>() {
114   return F64;
115 }
116 
117 template <>
118 inline PrimitiveType NativeToPrimitiveType<half>() {
119   return F16;
120 }
121 
122 template <>
123 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
124   return BF16;
125 }
126 
127 // Complex
128 template <>
129 inline PrimitiveType NativeToPrimitiveType<complex64>() {
130   return C64;
131 }
132 
133 template <>
134 inline PrimitiveType NativeToPrimitiveType<complex128>() {
135   return C128;
136 }
137 
138 bool IsFloatingPointType(PrimitiveType type);
139 
140 bool IsComplexType(PrimitiveType type);
141 
142 bool IsSignedIntegralType(PrimitiveType type);
143 
144 bool IsUnsignedIntegralType(PrimitiveType type);
145 
146 bool IsIntegralType(PrimitiveType type);
147 
148 // Returns true if values of the given primitive type are held in array shapes.
149 bool IsArrayType(PrimitiveType primitive_type);
150 
151 // Returns the number of bits in the representation for a given type.
152 int BitWidth(PrimitiveType type);
153 
154 PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
155 
156 // Returns the real, imag component type underlying the given complex type.
157 // LOG(FATAL)'s if complex_type is not complex.
158 PrimitiveType ComplexComponentType(PrimitiveType complex_type);
159 
160 // Returns the native type (eg, float) corresponding to the given template
161 // parameter XLA primitive type (eg, F32).
162 template <PrimitiveType>
163 struct PrimitiveTypeToNative;
164 
165 // Declarations of specializations for each native type which correspond to a
166 // XLA primitive type.
167 template <>
168 struct PrimitiveTypeToNative<PRED> {
169   using type = bool;
170 };
171 
172 // Unsigned integer
173 template <>
174 struct PrimitiveTypeToNative<U8> {
175   using type = uint8;
176 };
177 
178 template <>
179 struct PrimitiveTypeToNative<U16> {
180   using type = uint16;
181 };
182 
183 template <>
184 struct PrimitiveTypeToNative<U32> {
185   using type = uint32;
186 };
187 
188 template <>
189 struct PrimitiveTypeToNative<U64> {
190   using type = uint64;
191 };
192 
193 // Signed integer
194 template <>
195 struct PrimitiveTypeToNative<S8> {
196   using type = int8;
197 };
198 
199 template <>
200 struct PrimitiveTypeToNative<S16> {
201   using type = int16;
202 };
203 
204 template <>
205 struct PrimitiveTypeToNative<S32> {
206   using type = int32;
207 };
208 
209 template <>
210 struct PrimitiveTypeToNative<S64> {
211   using type = int64;
212 };
213 
214 // Floating point
215 template <>
216 struct PrimitiveTypeToNative<F32> {
217   using type = float;
218 };
219 template <>
220 struct PrimitiveTypeToNative<F64> {
221   using type = double;
222 };
223 template <>
224 struct PrimitiveTypeToNative<F16> {
225   using type = half;
226 };
227 
228 template <>
229 struct PrimitiveTypeToNative<BF16> {
230   using type = bfloat16;
231 };
232 
233 // Complex
234 template <>
235 struct PrimitiveTypeToNative<C64> {
236   using type = complex64;
237 };
238 
239 template <>
240 struct PrimitiveTypeToNative<C128> {
241   using type = complex128;
242 };
243 
244 // Returns the lower-case name of the given primitive type.
245 const string& LowercasePrimitiveTypeName(PrimitiveType s);
246 
247 // Returns the PrimitiveType matching the given name. The given name is expected
248 // to be lower-case.
249 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);
250 
251 // Returns true if the given name is a primitive type string (lower-case).
252 bool IsPrimitiveTypeName(absl::string_view name);
253 
254 }  // namespace primitive_util
255 }  // namespace xla
256 
257 #endif  // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
258