• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/primitive_util.h"
17 
18 #include "absl/strings/ascii.h"
19 #include "absl/strings/numbers.h"
20 #include "tensorflow/compiler/xla/util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace xla {
25 namespace primitive_util {
26 
SignificandWidth(PrimitiveType type)27 int SignificandWidth(PrimitiveType type) {
28   switch (type) {
29     case F32:
30       return std::numeric_limits<float>::digits;
31     case F64:
32       return std::numeric_limits<double>::digits;
33     case BF16:
34       return kBFloat16MantissaBits + 1;
35     case F16:
36       return 11;
37     default:
38       LOG(FATAL) << "Not a floating data type " << type;
39   }
40 }
41 
IsFloatingPointType(PrimitiveType type)42 bool IsFloatingPointType(PrimitiveType type) {
43   return type == F16 || type == F32 || type == F64 || type == BF16;
44 }
45 
IsComplexType(PrimitiveType type)46 bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
47 
IsSignedIntegralType(PrimitiveType type)48 bool IsSignedIntegralType(PrimitiveType type) {
49   return type == S8 || type == S16 || type == S32 || type == S64;
50 }
51 
IsUnsignedIntegralType(PrimitiveType type)52 bool IsUnsignedIntegralType(PrimitiveType type) {
53   return type == U8 || type == U16 || type == U32 || type == U64;
54 }
55 
IsIntegralType(PrimitiveType type)56 bool IsIntegralType(PrimitiveType type) {
57   return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
58 }
59 
BitWidth(PrimitiveType type)60 int BitWidth(PrimitiveType type) {
61   switch (type) {
62     case PRED:
63       return 1;
64 
65     case S8:
66     case U8:
67       return 8;
68 
69     case S16:
70     case U16:
71     case F16:
72     case BF16:
73       return 16;
74 
75     case U32:
76     case S32:
77     case F32:
78       return 32;
79 
80     case U64:
81     case S64:
82     case F64:
83     case C64:
84       return 64;
85 
86     case C128:
87       return 128;
88 
89     case TUPLE:
90       LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
91 
92     case OPAQUE_TYPE:
93       LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
94 
95     default:
96       LOG(FATAL) << "Unhandled primitive type " << type;
97   }
98 }
99 
UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)100 xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
101   switch (src_bitwidth) {
102     case 8:
103       return xla::U8;
104     case 16:
105       return xla::U16;
106     case 32:
107       return xla::U32;
108     case 64:
109       return xla::U64;
110     default:
111       return xla::PRIMITIVE_TYPE_INVALID;
112   }
113 }
114 
ComplexComponentType(PrimitiveType complex_type)115 PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
116   switch (complex_type) {
117     case C64:
118       return F32;
119     case C128:
120       return F64;
121     default:
122       LOG(FATAL) << "Primitive type is not complex: "
123                  << PrimitiveType_Name(complex_type);
124   }
125 }
126 
IsArrayType(PrimitiveType primitive_type)127 bool IsArrayType(PrimitiveType primitive_type) {
128   return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
129          primitive_type != OPAQUE_TYPE && primitive_type != TOKEN;
130 }
131 
132 // Class to memoize the computation of
133 //   absl::AsciiStrToLower(PrimitiveType_Name(p))
134 // for all PrimitiveType values "p"
135 //
136 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason
137 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro.
138 class PrimitiveTypeNameGenerator {
139  public:
PrimitiveTypeNameGenerator()140   PrimitiveTypeNameGenerator() {
141     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
142       if (i == static_cast<int>(OPAQUE_TYPE)) {
143         lowercase_name_[i] = "opaque";
144       } else if (PrimitiveType_IsValid(i)) {
145         lowercase_name_[i] = absl::AsciiStrToLower(
146             PrimitiveType_Name(static_cast<PrimitiveType>(i)));
147       }
148     }
149   }
LowercaseName(PrimitiveType t)150   const string& LowercaseName(PrimitiveType t) {
151     return lowercase_name_[static_cast<int>(t)];
152   }
153 
154  private:
155   string lowercase_name_[PrimitiveType_ARRAYSIZE];
156 };
157 
LowercasePrimitiveTypeName(PrimitiveType s)158 const string& LowercasePrimitiveTypeName(PrimitiveType s) {
159   static auto* gen = new PrimitiveTypeNameGenerator();
160   return gen->LowercaseName(s);
161 }
162 
163 namespace {
164 
165 // Returns a map from lower-case primitive type name to primitive type.
166 //
167 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to
168 // the xla::OPAQUE_TYPE enumerator.
GetPrimitiveTypeStringMap()169 const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() {
170   static std::unordered_map<string, PrimitiveType>* name_to_type = [] {
171     static auto* map = new std::unordered_map<string, PrimitiveType>;
172     for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) {
173       if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) {
174         auto value = static_cast<PrimitiveType>(i);
175         (*map)[LowercasePrimitiveTypeName(value)] = value;
176       }
177     }
178     (*map)["opaque"] = OPAQUE_TYPE;
179     return map;
180   }();
181   return *name_to_type;
182 }
183 
184 }  // namespace
185 
StringToPrimitiveType(absl::string_view name)186 StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) {
187   const auto& map = GetPrimitiveTypeStringMap();
188   auto found = map.find(string(name));
189   if (found == map.end()) {
190     return InvalidArgument("Invalid element type string: \"%s\".", name);
191   }
192   return found->second;
193 }
194 
IsPrimitiveTypeName(absl::string_view name)195 bool IsPrimitiveTypeName(absl::string_view name) {
196   const auto& map = GetPrimitiveTypeStringMap();
197   auto found = map.find(string(name));
198   return found != map.end();
199 }
200 
201 }  // namespace primitive_util
202 }  // namespace xla
203