• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/primitive_util.h"
17 
18 #include <limits>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/numbers.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace xla {
29 namespace primitive_util {
30 
SignificandWidth(PrimitiveType type)31 int SignificandWidth(PrimitiveType type) {
32   switch (type) {
33     case F32:
34       return std::numeric_limits<float>::digits;
35     case F64:
36       return std::numeric_limits<double>::digits;
37     case BF16:
38       return std::numeric_limits<bfloat16>::digits;
39     case F16:
40       return std::numeric_limits<half>::digits;
41     default:
42       LOG(FATAL) << "Not a floating data type " << type;
43   }
44 }
45 
ExponentWidth(PrimitiveType type)46 int ExponentWidth(PrimitiveType type) {
47   // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
48   // biased exponent and a trailing significand field.
49   int total_bit_width = BitWidth(type);
50   // This field contains all bits in the significand other than the leading
51   // digit which is implied by the exponent.
52   int trailing_significand_field_width = SignificandWidth(type) - 1;
53   // The sign is encoded with a single bit.
54   int kSignBitWidth = 1;
55   // The remaining bits are used for encoding the biased exponent.
56   return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
57 }
58 
OverflowExponent(PrimitiveType type)59 int OverflowExponent(PrimitiveType type) {
60   // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
61   // integer such that radix raised to the power one less than that integer is a
62   // representable finite floating-point number." as such it does not actually
63   // yield the maximum exponent but the exponent of the first integer which
64   // overflows.
65   switch (type) {
66     case F32:
67       return std::numeric_limits<float>::max_exponent;
68     case F64:
69       return std::numeric_limits<double>::max_exponent;
70     case BF16:
71       return std::numeric_limits<bfloat16>::max_exponent;
72     case F16:
73       return std::numeric_limits<half>::max_exponent;
74     default:
75       LOG(FATAL) << "Not a floating data type " << type;
76   }
77 }
78 
IsFloatingPointType(PrimitiveType type)79 bool IsFloatingPointType(PrimitiveType type) {
80   return type == F16 || type == F32 || type == F64 || type == BF16;
81 }
82 
IsComplexType(PrimitiveType type)83 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
84 
IsSignedIntegralType(PrimitiveType type)85 bool IsSignedIntegralType(PrimitiveType type) {
86   return type == S8 || type == S16 || type == S32 || type == S64;
87 }
88 
IsUnsignedIntegralType(PrimitiveType type)89 bool IsUnsignedIntegralType(PrimitiveType type) {
90   return type == U8 || type == U16 || type == U32 || type == U64;
91 }
92 
IsIntegralType(PrimitiveType type)93 bool IsIntegralType(PrimitiveType type) {
94   return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
95 }
96 
BitWidth(PrimitiveType type)97 int BitWidth(PrimitiveType type) {
98   switch (type) {
99     case PRED:
100       return 1;
101 
102     case S8:
103     case U8:
104       return 8;
105 
106     case S16:
107     case U16:
108     case F16:
109     case BF16:
110       return 16;
111 
112     case U32:
113     case S32:
114     case F32:
115       return 32;
116 
117     case U64:
118     case S64:
119     case F64:
120     case C64:
121       return 64;
122 
123     case C128:
124       return 128;
125 
126     case TUPLE:
127       LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
128 
129     case OPAQUE_TYPE:
130       LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
131 
132     default:
133       LOG(FATAL) << "Unhandled primitive type " << type;
134   }
135 }
136 
ByteWidth(PrimitiveType type)137 int ByteWidth(PrimitiveType type) { return CeilOfRatio(BitWidth(type), 8); }
138 
UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth)139 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
140   switch (src_bitwidth) {
141     case 8:
142       return xla::U8;
143     case 16:
144       return xla::U16;
145     case 32:
146       return xla::U32;
147     case 64:
148       return xla::U64;
149     default:
150       return xla::PRIMITIVE_TYPE_INVALID;
151   }
152 }
153 
SignedIntegralTypeForBitWidth(int64_t src_bitwidth)154 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
155   switch (src_bitwidth) {
156     case 8:
157       return xla::S8;
158     case 16:
159       return xla::S16;
160     case 32:
161       return xla::S32;
162     case 64:
163       return xla::S64;
164     default:
165       return xla::PRIMITIVE_TYPE_INVALID;
166   }
167 }
168 
ComplexComponentType(PrimitiveType complex_type)169 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
170   switch (complex_type) {
171     case C64:
172       return F32;
173     case C128:
174       return F64;
175     default:
176       LOG(FATAL) << "Primitive type is not complex: "
177                  << PrimitiveType_Name(complex_type);
178   }
179 }
180 
181 // Class to memoize the computation of
182 //   absl::AsciiStrToLower(PrimitiveType_Name(p))
183 // for all PrimitiveType values "p"
184 //
185 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
186 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
187 class PrimitiveTypeNameGenerator {
188  public:
PrimitiveTypeNameGenerator()189   PrimitiveTypeNameGenerator() {
190     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
191       if (i == static_cast<int>(OPAQUE_TYPE)) {
192         lowercase_name_[i] = "opaque";
193       } else if (PrimitiveType_IsValid(i)) {
194         lowercase_name_[i] = absl::AsciiStrToLower(
195             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
196       }
197     }
198   }
LowercaseName(PrimitiveType t)199   const std::string& LowercaseName(PrimitiveType t) {
200     return lowercase_name_[static_cast<int>(t)];
201   }
202 
203  private:
204   std::string lowercase_name_[PrimitiveType_ARRAYSIZE];
205 };
206 
LowercasePrimitiveTypeName(PrimitiveType s)207 const std::string& LowercasePrimitiveTypeName(PrimitiveType s) {
208   static auto* gen = new PrimitiveTypeNameGenerator();
209   return gen->LowercaseName(s);
210 }
211 
212 namespace {
213 
214 // Returns a map from lower-case primitive type name to primitive type.
215 //
216 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
217 // the xla::OPAQUE_TYPE enumerator.
218 const absl::flat_hash_map<std::string, PrimitiveType>&
GetPrimitiveTypeStringMap()219 GetPrimitiveTypeStringMap() {
220   static absl::flat_hash_map<std::string, PrimitiveType>* name_to_type = [] {
221     static auto* map = new absl::flat_hash_map<std::string, PrimitiveType>;
222     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
223       if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
224         auto value = static_cast<PrimitiveType>(i);
225         (*map)[LowercasePrimitiveTypeName(value)] = value;
226       }
227     }
228     (*map)["opaque"] = OPAQUE_TYPE;
229     return map;
230   }();
231   return *name_to_type;
232 }
233 
234 }  // namespace
235 
StringToPrimitiveType(absl::string_view name)236 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
237   const auto& map = GetPrimitiveTypeStringMap();
238   auto found = map.find(std::string(name));
239   if (found == map.end()) {
240     return InvalidArgument("Invalid element type string: \"%s\".", name);
241   }
242   return found->second;
243 }
244 
IsPrimitiveTypeName(absl::string_view name)245 bool IsPrimitiveTypeName(absl::string_view name) {
246   const auto& map = GetPrimitiveTypeStringMap();
247   auto found = map.find(std::string(name));
248   return found != map.end();
249 }
250 
251 }  // namespace primitive_util
252 }  // namespace xla
253