• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/primitive_util.h"
17 
18 #include <limits>
19 
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/numbers.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace xla {
28 namespace primitive_util {
29 
SignificandWidth(PrimitiveType type)30 int SignificandWidth(PrimitiveType type) {
31   switch (type) {
32     case F32:
33       return std::numeric_limits<float>::digits;
34     case F64:
35       return std::numeric_limits<double>::digits;
36     case BF16:
37       return std::numeric_limits<bfloat16>::digits;
38     case F16:
39       return std::numeric_limits<half>::digits;
40     default:
41       LOG(FATAL) << "Not a floating data type " << type;
42   }
43 }
44 
ExponentWidth(PrimitiveType type)45 int ExponentWidth(PrimitiveType type) {
46   // Per the IEEE-754 standard: a floating point type is stored as a sign bit, a
47   // biased exponent and a trailing significand field.
48   int total_bit_width = BitWidth(type);
49   // This field contains all bits in the significand other than the leading
50   // digit which is implied by the exponent.
51   int trailing_significand_field_width = SignificandWidth(type) - 1;
52   // The sign is encoded with a single bit.
53   int kSignBitWidth = 1;
54   // The remaining bits are used for encoding the biased exponent.
55   return total_bit_width - (trailing_significand_field_width + kSignBitWidth);
56 }
57 
OverflowExponent(PrimitiveType type)58 int OverflowExponent(PrimitiveType type) {
59   // |std::numeric_limits<float>::max_exponent| is defined as: "Maximum positive
60   // integer such that radix raised to the power one less than that integer is a
61   // representable finite floating-point number." as such it does not actually
62   // yield the maximum exponent but the exponent of the first integer which
63   // overflows.
64   switch (type) {
65     case F32:
66       return std::numeric_limits<float>::max_exponent;
67     case F64:
68       return std::numeric_limits<double>::max_exponent;
69     case BF16:
70       return std::numeric_limits<bfloat16>::max_exponent;
71     case F16:
72       return std::numeric_limits<half>::max_exponent;
73     default:
74       LOG(FATAL) << "Not a floating data type " << type;
75   }
76 }
77 
IsFloatingPointType(PrimitiveType type)78 bool IsFloatingPointType(PrimitiveType type) {
79   return type == F16 || type == F32 || type == F64 || type == BF16;
80 }
81 
IsComplexType(PrimitiveType type)82 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
83 
IsSignedIntegralType(PrimitiveType type)84 bool IsSignedIntegralType(PrimitiveType type) {
85   return type == S8 || type == S16 || type == S32 || type == S64;
86 }
87 
IsUnsignedIntegralType(PrimitiveType type)88 bool IsUnsignedIntegralType(PrimitiveType type) {
89   return type == U8 || type == U16 || type == U32 || type == U64;
90 }
91 
IsIntegralType(PrimitiveType type)92 bool IsIntegralType(PrimitiveType type) {
93   return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
94 }
95 
BitWidth(PrimitiveType type)96 int BitWidth(PrimitiveType type) {
97   switch (type) {
98     case PRED:
99       return 1;
100 
101     case S8:
102     case U8:
103       return 8;
104 
105     case S16:
106     case U16:
107     case F16:
108     case BF16:
109       return 16;
110 
111     case U32:
112     case S32:
113     case F32:
114       return 32;
115 
116     case U64:
117     case S64:
118     case F64:
119     case C64:
120       return 64;
121 
122     case C128:
123       return 128;
124 
125     case TUPLE:
126       LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
127 
128     case OPAQUE_TYPE:
129       LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
130 
131     default:
132       LOG(FATAL) << "Unhandled primitive type " << type;
133   }
134 }
135 
ByteWidth(PrimitiveType type)136 int ByteWidth(PrimitiveType type) { return CeilOfRatio(BitWidth(type), 8); }
137 
UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth)138 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
139   switch (src_bitwidth) {
140     case 8:
141       return xla::U8;
142     case 16:
143       return xla::U16;
144     case 32:
145       return xla::U32;
146     case 64:
147       return xla::U64;
148     default:
149       return xla::PRIMITIVE_TYPE_INVALID;
150   }
151 }
152 
SignedIntegralTypeForBitWidth(int64_t src_bitwidth)153 xla::PrimitiveType SignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
154   switch (src_bitwidth) {
155     case 8:
156       return xla::S8;
157     case 16:
158       return xla::S16;
159     case 32:
160       return xla::S32;
161     case 64:
162       return xla::S64;
163     default:
164       return xla::PRIMITIVE_TYPE_INVALID;
165   }
166 }
167 
ComplexComponentType(PrimitiveType complex_type)168 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
169   switch (complex_type) {
170     case C64:
171       return F32;
172     case C128:
173       return F64;
174     default:
175       LOG(FATAL) << "Primitive type is not complex: "
176                  << PrimitiveType_Name(complex_type);
177   }
178 }
179 
IsArrayType(PrimitiveType primitive_type)180 bool IsArrayType(PrimitiveType primitive_type) {
181   return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
182          primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
183 }
184 
185 // Class to memoize the computation of
186 //   absl::AsciiStrToLower(PrimitiveType_Name(p))
187 // for all PrimitiveType values "p"
188 //
189 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
190 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
191 class PrimitiveTypeNameGenerator {
192  public:
PrimitiveTypeNameGenerator()193   PrimitiveTypeNameGenerator() {
194     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
195       if (i == static_cast<int>(OPAQUE_TYPE)) {
196         lowercase_name_[i] = "opaque";
197       } else if (PrimitiveType_IsValid(i)) {
198         lowercase_name_[i] = absl::AsciiStrToLower(
199             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
200       }
201     }
202   }
LowercaseName(PrimitiveType t)203   const string& LowercaseName(PrimitiveType t) {
204     return lowercase_name_[static_cast<int>(t)];
205   }
206 
207  private:
208   string lowercase_name_[PrimitiveType_ARRAYSIZE];
209 };
210 
LowercasePrimitiveTypeName(PrimitiveType s)211 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
212   static auto* gen = new PrimitiveTypeNameGenerator();
213   return gen->LowercaseName(s);
214 }
215 
216 namespace {
217 
218 // Returns a map from lower-case primitive type name to primitive type.
219 //
220 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
221 // the xla::OPAQUE_TYPE enumerator.
GetPrimitiveTypeStringMap()222 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
223   static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
224     static auto* map = new std::unordered_map<string, PrimitiveType>;
225     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
226       if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
227         auto value = static_cast<PrimitiveType>(i);
228         (*map)[LowercasePrimitiveTypeName(value)] = value;
229       }
230     }
231     (*map)["opaque"] = OPAQUE_TYPE;
232     return map;
233   }();
234   return *name_to_type;
235 }
236 
237 }  // namespace
238 
StringToPrimitiveType(absl::string_view name)239 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
240   const auto& map = GetPrimitiveTypeStringMap();
241   auto found = map.find(string(name));
242   if (found == map.end()) {
243     return InvalidArgument("Invalid element type string: \"%s\".", name);
244   }
245   return found->second;
246 }
247 
IsPrimitiveTypeName(absl::string_view name)248 bool IsPrimitiveTypeName(absl::string_view name) {
249   const auto& map = GetPrimitiveTypeStringMap();
250   auto found = map.find(string(name));
251   return found != map.end();
252 }
253 
254 }  // namespace primitive_util
255 }  // namespace xla
256