• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/types.h"
17 #include "tensorflow/core/framework/register_types.h"
18 
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/logging.h"
22 
23 namespace tensorflow {
24 
operator <(const DeviceType & other) const25 bool DeviceType::operator<(const DeviceType& other) const {
26   return type_ < other.type_;
27 }
28 
operator ==(const DeviceType & other) const29 bool DeviceType::operator==(const DeviceType& other) const {
30   return type_ == other.type_;
31 }
32 
operator <<(std::ostream & os,const DeviceType & d)33 std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
34   os << d.type();
35   return os;
36 }
37 
// Canonical device-type name strings. Both the pointer and the pointed-to
// characters are const.
// NOTE(review): a namespace-scope `const` object has internal linkage unless an
// earlier `extern` declaration is visible; presumably types.h declares these
// `extern` so other translation units can use them -- confirm.
38 const char* const DEVICE_DEFAULT = "DEFAULT";
39 const char* const DEVICE_CPU = "CPU";
40 const char* const DEVICE_GPU = "GPU";
41 const char* const DEVICE_TPU_SYSTEM = "TPU_SYSTEM";
42 
// Device-name strings for the Eigen device specializations of DeviceName<>.
// They are initialized from the DEVICE_CPU / DEVICE_GPU constants defined just
// above in this file, whose string-literal initializers are already in place
// when these std::string objects are constructed.
// The CPU (Eigen::ThreadPoolDevice) name is always defined.
43 const std::string DeviceName<Eigen::ThreadPoolDevice>::value = DEVICE_CPU;
// The GPU (Eigen::GpuDevice) name is defined only in CUDA or ROCm builds.
44 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
45     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
46 const std::string DeviceName<Eigen::GpuDevice>::value = DEVICE_GPU;
47 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
48 
49 namespace {
DataTypeStringInternal(DataType dtype)50 string DataTypeStringInternal(DataType dtype) {
51   switch (dtype) {
52     case DT_INVALID:
53       return "INVALID";
54     case DT_FLOAT:
55       return "float";
56     case DT_DOUBLE:
57       return "double";
58     case DT_INT32:
59       return "int32";
60     case DT_UINT32:
61       return "uint32";
62     case DT_UINT8:
63       return "uint8";
64     case DT_UINT16:
65       return "uint16";
66     case DT_INT16:
67       return "int16";
68     case DT_INT8:
69       return "int8";
70     case DT_STRING:
71       return "string";
72     case DT_COMPLEX64:
73       return "complex64";
74     case DT_COMPLEX128:
75       return "complex128";
76     case DT_INT64:
77       return "int64";
78     case DT_UINT64:
79       return "uint64";
80     case DT_BOOL:
81       return "bool";
82     case DT_QINT8:
83       return "qint8";
84     case DT_QUINT8:
85       return "quint8";
86     case DT_QUINT16:
87       return "quint16";
88     case DT_QINT16:
89       return "qint16";
90     case DT_QINT32:
91       return "qint32";
92     case DT_BFLOAT16:
93       return "bfloat16";
94     case DT_HALF:
95       return "half";
96     case DT_RESOURCE:
97       return "resource";
98     case DT_VARIANT:
99       return "variant";
100     default:
101       LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
102       return strings::StrCat("unknown dtype enum (", dtype, ")");
103   }
104 }
105 }  // end namespace
106 
DataTypeString(DataType dtype)107 string DataTypeString(DataType dtype) {
108   if (IsRefType(dtype)) {
109     DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
110     return strings::StrCat(DataTypeStringInternal(non_ref), "_ref");
111   }
112   return DataTypeStringInternal(dtype);
113 }
114 
DataTypeFromString(StringPiece sp,DataType * dt)115 bool DataTypeFromString(StringPiece sp, DataType* dt) {
116   if (str_util::EndsWith(sp, "_ref")) {
117     sp.remove_suffix(4);
118     DataType non_ref;
119     if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
120       *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
121       return true;
122     } else {
123       return false;
124     }
125   }
126 
127   if (sp == "float" || sp == "float32") {
128     *dt = DT_FLOAT;
129     return true;
130   } else if (sp == "double" || sp == "float64") {
131     *dt = DT_DOUBLE;
132     return true;
133   } else if (sp == "int32") {
134     *dt = DT_INT32;
135     return true;
136   } else if (sp == "uint32") {
137     *dt = DT_UINT32;
138     return true;
139   } else if (sp == "uint8") {
140     *dt = DT_UINT8;
141     return true;
142   } else if (sp == "uint16") {
143     *dt = DT_UINT16;
144     return true;
145   } else if (sp == "int16") {
146     *dt = DT_INT16;
147     return true;
148   } else if (sp == "int8") {
149     *dt = DT_INT8;
150     return true;
151   } else if (sp == "string") {
152     *dt = DT_STRING;
153     return true;
154   } else if (sp == "complex64") {
155     *dt = DT_COMPLEX64;
156     return true;
157   } else if (sp == "complex128") {
158     *dt = DT_COMPLEX128;
159     return true;
160   } else if (sp == "int64") {
161     *dt = DT_INT64;
162     return true;
163   } else if (sp == "uint64") {
164     *dt = DT_UINT64;
165     return true;
166   } else if (sp == "bool") {
167     *dt = DT_BOOL;
168     return true;
169   } else if (sp == "qint8") {
170     *dt = DT_QINT8;
171     return true;
172   } else if (sp == "quint8") {
173     *dt = DT_QUINT8;
174     return true;
175   } else if (sp == "qint16") {
176     *dt = DT_QINT16;
177     return true;
178   } else if (sp == "quint16") {
179     *dt = DT_QUINT16;
180     return true;
181   } else if (sp == "qint32") {
182     *dt = DT_QINT32;
183     return true;
184   } else if (sp == "bfloat16") {
185     *dt = DT_BFLOAT16;
186     return true;
187   } else if (sp == "half" || sp == "float16") {
188     *dt = DT_HALF;
189     return true;
190   } else if (sp == "resource") {
191     *dt = DT_RESOURCE;
192     return true;
193   } else if (sp == "variant") {
194     *dt = DT_VARIANT;
195     return true;
196   }
197   return false;
198 }
199 
DeviceTypeString(const DeviceType & device_type)200 string DeviceTypeString(const DeviceType& device_type) {
201   return device_type.type();
202 }
203 
DataTypeSliceString(const DataTypeSlice types)204 string DataTypeSliceString(const DataTypeSlice types) {
205   string out;
206   for (auto it = types.begin(); it != types.end(); ++it) {
207     strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
208                        DataTypeString(*it));
209   }
210   return out;
211 }
212 
DataTypeAlwaysOnHost(DataType dt)213 bool DataTypeAlwaysOnHost(DataType dt) {
214   // Includes DT_STRING and DT_RESOURCE.
215   switch (dt) {
216     case DT_STRING:
217     case DT_STRING_REF:
218     case DT_RESOURCE:
219       return true;
220     default:
221       return false;
222   }
223 }
224 
// Returns sizeof() of the C++ type corresponding to `dt`, for the dtypes
// enumerated by the TF_CALL_* macros below: POD types, quantized types, and
// explicitly qint16/quint16. Every other dtype yields 0.
DataTypeSize(DataType dt)225 int DataTypeSize(DataType dt) {
226 #define CASE(T)                  \
227   case DataTypeToEnum<T>::value: \
228     return sizeof(T);
  // Each CASE(T) expands to one `case` label that maps T's enum value to
  // sizeof(T).
229   switch (dt) {
230     TF_CALL_POD_TYPES(CASE);
231     TF_CALL_QUANTIZED_TYPES(CASE);
232     // TF_CALL_QUANTIZED_TYPES() macro does not cover quint16 and qint16, since
233     // they are not supported widely, but are explicitly listed here for
234     // bitcast.
235     TF_CALL_qint16(CASE);
236     TF_CALL_quint16(CASE);
237
238     default:
      // NOTE(review): dtypes not covered above report size 0. Presumably this
      // includes all reference (_REF) values and non-POD types such as
      // string/resource/variant; confirm against the TF_CALL_* definitions.
239       return 0;
240   }
241 #undef CASE
242 }
243 
244 // Define DataTypeToEnum<T>::value.
// These are out-of-class definitions of the in-class `constexpr` static data
// members. Before C++17, such members are not implicitly inline, so a single
// namespace-scope definition is required whenever the member is odr-used
// (e.g. bound to a const reference).
// NOTE(review): this list presumably must stay in sync with the
// DataTypeToEnum specializations declared in the header -- confirm when adding
// a new dtype.
245 #define DEFINE_DATATYPETOENUM_VALUE(TYPE) \
246   constexpr DataType DataTypeToEnum<TYPE>::value;
247
248 DEFINE_DATATYPETOENUM_VALUE(float);
249 DEFINE_DATATYPETOENUM_VALUE(double);
250 DEFINE_DATATYPETOENUM_VALUE(int32);
251 DEFINE_DATATYPETOENUM_VALUE(uint32);
252 DEFINE_DATATYPETOENUM_VALUE(uint16);
253 DEFINE_DATATYPETOENUM_VALUE(uint8);
254 DEFINE_DATATYPETOENUM_VALUE(int16);
255 DEFINE_DATATYPETOENUM_VALUE(int8);
256 DEFINE_DATATYPETOENUM_VALUE(tstring);
257 DEFINE_DATATYPETOENUM_VALUE(complex64);
258 DEFINE_DATATYPETOENUM_VALUE(complex128);
259 DEFINE_DATATYPETOENUM_VALUE(int64);
260 DEFINE_DATATYPETOENUM_VALUE(uint64);
261 DEFINE_DATATYPETOENUM_VALUE(bool);
262 DEFINE_DATATYPETOENUM_VALUE(qint8);
263 DEFINE_DATATYPETOENUM_VALUE(quint8);
264 DEFINE_DATATYPETOENUM_VALUE(qint16);
265 DEFINE_DATATYPETOENUM_VALUE(quint16);
266 DEFINE_DATATYPETOENUM_VALUE(qint32);
267 DEFINE_DATATYPETOENUM_VALUE(bfloat16);
268 DEFINE_DATATYPETOENUM_VALUE(Eigen::half);
269 DEFINE_DATATYPETOENUM_VALUE(ResourceHandle);
270 DEFINE_DATATYPETOENUM_VALUE(Variant);
271 #undef DEFINE_DATATYPETOENUM_VALUE
272 
273 }  // namespace tensorflow
274