1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/primitive_util.h" 17 18 #include "absl/strings/ascii.h" 19 #include "absl/strings/numbers.h" 20 #include "tensorflow/compiler/xla/util.h" 21 #include "tensorflow/compiler/xla/xla_data.pb.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace xla { 25 namespace primitive_util { 26 SignificandWidth(PrimitiveType type)27int SignificandWidth(PrimitiveType type) { 28 switch (type) { 29 case F32: 30 return std::numeric_limits<float>::digits; 31 case F64: 32 return std::numeric_limits<double>::digits; 33 case BF16: 34 return kBFloat16MantissaBits + 1; 35 case F16: 36 return 11; 37 default: 38 LOG(FATAL) << "Not a floating data type " << type; 39 } 40 } 41 IsFloatingPointType(PrimitiveType type)42bool IsFloatingPointType(PrimitiveType type) { 43 return type == F16 || type == F32 || type == F64 || type == BF16; 44 } 45 IsComplexType(PrimitiveType type)46bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } 47 IsSignedIntegralType(PrimitiveType type)48bool IsSignedIntegralType(PrimitiveType type) { 49 return type == S8 || type == S16 || type == S32 || type == S64; 50 } 51 IsUnsignedIntegralType(PrimitiveType type)52bool IsUnsignedIntegralType(PrimitiveType type) { 53 return type == U8 || type == U16 || type == U32 || type == U64; 54 } 55 IsIntegralType(PrimitiveType type)56bool IsIntegralType(PrimitiveType type) { 57 return IsUnsignedIntegralType(type) || IsSignedIntegralType(type); 58 } 59 BitWidth(PrimitiveType type)60int BitWidth(PrimitiveType type) { 61 switch (type) { 62 case PRED: 63 return 1; 64 65 case S8: 66 case U8: 67 return 8; 68 69 case S16: 70 case U16: 71 case F16: 72 case BF16: 73 return 16; 74 75 case U32: 76 case S32: 77 case F32: 78 return 32; 79 80 case U64: 81 case S64: 82 case F64: 83 case C64: 84 return 64; 85 86 case C128: 87 return 128; 88 89 case TUPLE: 90 LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; 91 92 case OPAQUE_TYPE: 93 LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth"; 94 95 default: 96 LOG(FATAL) << "Unhandled primitive type " << type; 97 } 98 } 99 UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)100xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { 101 switch (src_bitwidth) { 102 case 8: 103 return xla::U8; 104 case 16: 105 return xla::U16; 106 case 32: 107 return xla::U32; 108 case 64: 109 return xla::U64; 110 default: 111 return xla::PRIMITIVE_TYPE_INVALID; 112 } 113 } 114 ComplexComponentType(PrimitiveType complex_type)115PrimitiveType ComplexComponentType(PrimitiveType complex_type) { 116 switch (complex_type) { 117 case C64: 118 return F32; 119 case C128: 120 return F64; 121 default: 122 LOG(FATAL) << "Primitive type is not complex: " 123 << PrimitiveType_Name(complex_type); 124 } 125 } 126 IsArrayType(PrimitiveType primitive_type)127bool IsArrayType(PrimitiveType primitive_type) { 128 return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && 129 primitive_type != OPAQUE_TYPE && primitive_type != TOKEN; 130 } 131 132 // Class to memoize the computation of 133 // absl::AsciiStrToLower(PrimitiveType_Name(p)) 134 // for all PrimitiveType values "p" 135 // 136 // xla::OPAQUE_TYPE canonically maps to the string "opaque" -- the only reason 137 // it's called OPAQUE_TYPE is to avoid clashing with a windows.h macro. 138 class PrimitiveTypeNameGenerator { 139 public: PrimitiveTypeNameGenerator()140 PrimitiveTypeNameGenerator() { 141 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 142 if (i == static_cast<int>(OPAQUE_TYPE)) { 143 lowercase_name_[i] = "opaque"; 144 } else if (PrimitiveType_IsValid(i)) { 145 lowercase_name_[i] = absl::AsciiStrToLower( 146 PrimitiveType_Name(static_cast<PrimitiveType>(i))); 147 } 148 } 149 } LowercaseName(PrimitiveType t)150 const string& LowercaseName(PrimitiveType t) { 151 return lowercase_name_[static_cast<int>(t)]; 152 } 153 154 private: 155 string lowercase_name_[PrimitiveType_ARRAYSIZE]; 156 }; 157 LowercasePrimitiveTypeName(PrimitiveType s)158const string& LowercasePrimitiveTypeName(PrimitiveType s) { 159 static auto* gen = new PrimitiveTypeNameGenerator(); 160 return gen->LowercaseName(s); 161 } 162 163 namespace { 164 165 // Returns a map from lower-case primitive type name to primitive type. 166 // 167 // Due to Postel's Law considerations, both "opaque" and "opaque_type" map to 168 // the xla::OPAQUE_TYPE enumerator. GetPrimitiveTypeStringMap()169const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() { 170 static std::unordered_map<string, PrimitiveType>* name_to_type = [] { 171 static auto* map = new std::unordered_map<string, PrimitiveType>; 172 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 173 if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { 174 auto value = static_cast<PrimitiveType>(i); 175 (*map)[LowercasePrimitiveTypeName(value)] = value; 176 } 177 } 178 (*map)["opaque"] = OPAQUE_TYPE; 179 return map; 180 }(); 181 return *name_to_type; 182 } 183 184 } // namespace 185 StringToPrimitiveType(absl::string_view name)186StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) { 187 const auto& map = GetPrimitiveTypeStringMap(); 188 auto found = map.find(string(name)); 189 if (found == map.end()) { 190 return InvalidArgument("Invalid element type string: \"%s\".", name); 191 } 192 return found->second; 193 } 194 IsPrimitiveTypeName(absl::string_view name)195bool IsPrimitiveTypeName(absl::string_view name) { 196 const auto& map = GetPrimitiveTypeStringMap(); 197 auto found = map.find(string(name)); 198 return found != map.end(); 199 } 200 201 } // namespace primitive_util 202 } // namespace xla 203