1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/primitive_util.h" 17 18 #include "absl/strings/ascii.h" 19 #include "absl/strings/numbers.h" 20 #include "tensorflow/compiler/xla/util.h" 21 #include "tensorflow/compiler/xla/xla_data.pb.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace xla { 25 namespace primitive_util { 26 SignificandWidth(PrimitiveType type)27int SignificandWidth(PrimitiveType type) { 28 switch (type) { 29 case F32: 30 return std::numeric_limits<float>::digits; 31 case F64: 32 return std::numeric_limits<double>::digits; 33 case BF16: 34 return kBFloat16MantissaBits + 1; 35 case F16: 36 return 11; 37 default: 38 LOG(FATAL) << "Not a floating data type " << type; 39 } 40 } 41 IsFloatingPointType(PrimitiveType type)42bool IsFloatingPointType(PrimitiveType type) { 43 return type == F16 || type == F32 || type == F64 || type == BF16; 44 } 45 IsComplexType(PrimitiveType type)46bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } 47 IsSignedIntegralType(PrimitiveType type)48bool IsSignedIntegralType(PrimitiveType type) { 49 return type == S8 || type == S16 || type == S32 || type == S64; 50 } 51 IsUnsignedIntegralType(PrimitiveType type)52bool IsUnsignedIntegralType(PrimitiveType type) { 53 return type == U8 || type == U16 || type == U32 || type == U64; 54 } 55 IsIntegralType(PrimitiveType type)56bool IsIntegralType(PrimitiveType type) { 57 return IsUnsignedIntegralType(type) || IsSignedIntegralType(type); 58 } 59 BitWidth(PrimitiveType type)60int BitWidth(PrimitiveType type) { 61 switch (type) { 62 case PRED: 63 return 1; 64 65 case S8: 66 case U8: 67 return 8; 68 69 case S16: 70 case U16: 71 case F16: 72 case BF16: 73 return 16; 74 75 case U32: 76 case S32: 77 case F32: 78 return 32; 79 80 case U64: 81 case S64: 82 case F64: 83 case C64: 84 return 64; 85 86 case C128: 87 return 128; 88 89 case TUPLE: 90 LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; 91 92 case OPAQUE: 93 LOG(FATAL) << "OPAQUE is an invalid type for BitWidth"; 94 95 default: 96 LOG(FATAL) << "Unhandled primitive type " << type; 97 } 98 } 99 UnsignedIntegralTypeForBitWidth(int64 src_bitwidth)100xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { 101 switch (src_bitwidth) { 102 case 8: 103 return xla::U8; 104 case 16: 105 return xla::U16; 106 case 32: 107 return xla::U32; 108 case 64: 109 return xla::U64; 110 default: 111 return xla::PRIMITIVE_TYPE_INVALID; 112 } 113 } 114 ComplexComponentType(PrimitiveType complex_type)115PrimitiveType ComplexComponentType(PrimitiveType complex_type) { 116 switch (complex_type) { 117 case C64: 118 return F32; 119 case C128: 120 return F64; 121 default: 122 LOG(FATAL) << "Primitive type is not complex: " 123 << PrimitiveType_Name(complex_type); 124 } 125 } 126 IsArrayType(PrimitiveType primitive_type)127bool IsArrayType(PrimitiveType primitive_type) { 128 return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && 129 primitive_type != OPAQUE && primitive_type != TOKEN; 130 } 131 132 // Class to memoize the computation of 133 // absl::AsciiStrToLower(PrimitiveType_Name(p)) 134 // for all PrimitiveType values "p" 135 class PrimitiveTypeNameGenerator { 136 public: PrimitiveTypeNameGenerator()137 PrimitiveTypeNameGenerator() { 138 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 139 if (PrimitiveType_IsValid(i)) { 140 lowercase_name_[i] = absl::AsciiStrToLower( 141 PrimitiveType_Name(static_cast<PrimitiveType>(i))); 142 } 143 } 144 } LowercaseName(PrimitiveType t)145 const string& LowercaseName(PrimitiveType t) { 146 return lowercase_name_[static_cast<int>(t)]; 147 } 148 149 private: 150 string lowercase_name_[PrimitiveType_ARRAYSIZE]; 151 }; 152 LowercasePrimitiveTypeName(PrimitiveType s)153const string& LowercasePrimitiveTypeName(PrimitiveType s) { 154 static auto* gen = new PrimitiveTypeNameGenerator(); 155 return gen->LowercaseName(s); 156 } 157 158 namespace { 159 160 // Returns a map from lower-case primitive type name to primitive type. GetPrimitiveTypeStringMap()161const std::unordered_map<string, PrimitiveType>& GetPrimitiveTypeStringMap() { 162 static std::unordered_map<string, PrimitiveType>* name_to_type = [] { 163 static auto* map = new std::unordered_map<string, PrimitiveType>; 164 for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { 165 if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { 166 auto value = static_cast<PrimitiveType>(i); 167 (*map)[LowercasePrimitiveTypeName(value)] = value; 168 } 169 } 170 return map; 171 }(); 172 return *name_to_type; 173 } 174 175 } // namespace 176 StringToPrimitiveType(absl::string_view name)177StatusOr<PrimitiveType> StringToPrimitiveType(absl::string_view name) { 178 const auto& map = GetPrimitiveTypeStringMap(); 179 auto found = map.find(string(name)); 180 if (found == map.end()) { 181 return InvalidArgument("Invalid element type string: \"%s\".", name); 182 } 183 return found->second; 184 } 185 IsPrimitiveTypeName(absl::string_view name)186bool IsPrimitiveTypeName(absl::string_view name) { 187 const auto& map = GetPrimitiveTypeStringMap(); 188 auto found = map.find(string(name)); 189 return found != map.end(); 190 } 191 192 } // namespace primitive_util 193 } // namespace xla 194