• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for dealing with XLA primitive types.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
19 #define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
20 
21 #include <type_traits>
22 
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace primitive_util {
31 
32 // Returns the count of significand (mantissa) bits for float datatypes.
33 // For non-float datatypes, results in a LOG(FATAL).
34 int SignificandWidth(PrimitiveType type);
35 
36 // Returns the count of exponent bits for float datatypes.
37 // For non-float datatypes, results in a LOG(FATAL).
38 int ExponentWidth(PrimitiveType type);
39 
40 // Returns the exponent of the smallest number which cannot be represented.
41 // For non-float datatypes, results in a LOG(FATAL).
42 int OverflowExponent(PrimitiveType type);
43 
44 // Returns the XLA primitive type (eg, F32) corresponding to the given
45 // template parameter native type (eg, float).
46 template <typename NativeT>
NativeToPrimitiveType()47 PrimitiveType NativeToPrimitiveType() {
48   // Make the expression depend on the template parameter NativeT so
49   // that this compile-time error only appears if this function is
50   // instantiated with some concrete type that is not specialized
51   // below.
52   static_assert(!std::is_same<NativeT, NativeT>::value,
53                 "Cannot map native type to primitive type.");
54   return PRIMITIVE_TYPE_INVALID;
55 }
56 
57 // Declarations of specializations for each native type which correspond to a
58 // XLA primitive type.  As an optimization, these are declared inline in the
59 // header.
60 template <>
61 inline PrimitiveType NativeToPrimitiveType<bool>() {
62   return PRED;
63 }
64 
65 // Unsigned integer
66 template <>
67 inline PrimitiveType NativeToPrimitiveType<uint8>() {
68   return U8;
69 }
70 
71 template <>
72 inline PrimitiveType NativeToPrimitiveType<uint16>() {
73   return U16;
74 }
75 
76 template <>
77 inline PrimitiveType NativeToPrimitiveType<uint32>() {
78   return U32;
79 }
80 
81 template <>
82 inline PrimitiveType NativeToPrimitiveType<uint64>() {
83   return U64;
84 }
85 
86 // Signed integer
87 template <>
88 inline PrimitiveType NativeToPrimitiveType<int8>() {
89   return S8;
90 }
91 
92 template <>
93 inline PrimitiveType NativeToPrimitiveType<int16>() {
94   return S16;
95 }
96 
97 template <>
98 inline PrimitiveType NativeToPrimitiveType<int32>() {
99   return S32;
100 }
101 
102 template <>
103 inline PrimitiveType NativeToPrimitiveType<int64>() {
104   return S64;
105 }
106 
107 // Floating point
108 template <>
109 inline PrimitiveType NativeToPrimitiveType<float>() {
110   return F32;
111 }
112 
113 template <>
114 inline PrimitiveType NativeToPrimitiveType<double>() {
115   return F64;
116 }
117 
118 template <>
119 inline PrimitiveType NativeToPrimitiveType<half>() {
120   return F16;
121 }
122 
123 template <>
124 inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
125   return BF16;
126 }
127 
128 // Complex
129 template <>
130 inline PrimitiveType NativeToPrimitiveType<complex64>() {
131   return C64;
132 }
133 
134 template <>
135 inline PrimitiveType NativeToPrimitiveType<complex128>() {
136   return C128;
137 }
138 
139 bool IsFloatingPointType(PrimitiveType type);
140 
141 bool IsComplexType(PrimitiveType type);
142 
143 bool IsSignedIntegralType(PrimitiveType type);
144 
145 bool IsUnsignedIntegralType(PrimitiveType type);
146 
147 bool IsIntegralType(PrimitiveType type);
148 
149 // Returns true if values of the given primitive type are held in array shapes.
150 bool IsArrayType(PrimitiveType primitive_type);
151 
152 // Returns the number of bits in the representation for a given type.
153 int BitWidth(PrimitiveType type);
154 
155 PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
156 
157 PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth);
158 
159 // Returns the real, imag component type underlying the given complex type.
160 // LOG(FATAL)'s if complex_type is not complex.
161 PrimitiveType ComplexComponentType(PrimitiveType complex_type);
162 
163 // Returns the native type (eg, float) corresponding to the given template
164 // parameter XLA primitive type (eg, F32).
165 template <PrimitiveType>
166 struct PrimitiveTypeToNative;
167 
168 // Declarations of specializations for each native type which correspond to a
169 // XLA primitive type.
170 template <>
171 struct PrimitiveTypeToNative<PRED> {
172   using type = bool;
173 };
174 
175 // Unsigned integer
176 template <>
177 struct PrimitiveTypeToNative<U8> {
178   using type = uint8;
179 };
180 
181 template <>
182 struct PrimitiveTypeToNative<U16> {
183   using type = uint16;
184 };
185 
186 template <>
187 struct PrimitiveTypeToNative<U32> {
188   using type = uint32;
189 };
190 
191 template <>
192 struct PrimitiveTypeToNative<U64> {
193   using type = uint64;
194 };
195 
196 // Signed integer
197 template <>
198 struct PrimitiveTypeToNative<S8> {
199   using type = int8;
200 };
201 
202 template <>
203 struct PrimitiveTypeToNative<S16> {
204   using type = int16;
205 };
206 
207 template <>
208 struct PrimitiveTypeToNative<S32> {
209   using type = int32;
210 };
211 
212 template <>
213 struct PrimitiveTypeToNative<S64> {
214   using type = int64;
215 };
216 
217 // Floating point
218 template <>
219 struct PrimitiveTypeToNative<F32> {
220   using type = float;
221 };
222 template <>
223 struct PrimitiveTypeToNative<F64> {
224   using type = double;
225 };
226 template <>
227 struct PrimitiveTypeToNative<F16> {
228   using type = half;
229 };
230 
231 template <>
232 struct PrimitiveTypeToNative<BF16> {
233   using type = bfloat16;
234 };
235 
236 // Complex
237 template <>
238 struct PrimitiveTypeToNative<C64> {
239   using type = complex64;
240 };
241 
242 template <>
243 struct PrimitiveTypeToNative<C128> {
244   using type = complex128;
245 };
246 
247 // Returns the lower-case name of the given primitive type.
248 const string& LowercasePrimitiveTypeName(PrimitiveType s);
249 
250 // Returns the PrimitiveType matching the given name. The given name is expected
251 // to be lower-case.
252 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name);
253 
254 // Returns true if the given name is a primitive type string (lower-case).
255 bool IsPrimitiveTypeName(absl::string_view name);
256 
257 }  // namespace primitive_util
258 }  // namespace xla
259 
260 #endif  // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
261