/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"

#include <cstdint>
#include <limits>

#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

using llvm::ArrayRef;
using llvm::SmallVector;
using mlir::Builder;
using mlir::DenseStringElementsAttr;
using mlir::ElementsAttr;
using mlir::OpaqueElementsAttr;
using mlir::RankedTensorType;
using mlir::ShapedType;
using mlir::Type;
using tensorflow::errors::InvalidArgument;

static TensorProto ConvertToProto(const Tensor& input_tensor,
                                  bool use_tensor_content = true) {
  TensorProto tensor_proto;
  // Using tensor content (mostly*) reduces serialization overhead during RPC
  // calls, but is less human-readable. Protobufs are read by humans far less
  // often than they are serialized, so default to the tensor content
  // representation.
  // * For scalars and short strings it may be marginally worse, and a more
  //   intelligent decision could be made by the caller.
  if (use_tensor_content)
    input_tensor.AsProtoTensorContent(&tensor_proto);
  else
    input_tensor.AsProtoField(&tensor_proto);
  return tensor_proto;
}

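// Converts a tensor to its mangled, serialized TensorProto string form, as
// stored in opaque elements attributes.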
static std::string MangleTensor(const Tensor& tensor) {
  return mangling_util::MangleTensor(ConvertToProto(tensor));
}

// Converts a TensorFlow tensor into an MLIR elements attribute.
template <typename T>
StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
                                         ShapedType type) {
  auto arr = input_tensor.flat<T>();
  return mlir::DenseElementsAttr::get(
      type, llvm::makeArrayRef(arr.data(), arr.size()));
}

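// Converts a bfloat16 TensorFlow tensor to an MLIR elements attribute by
// reinterpreting the tensor's raw byte buffer.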
ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
                               RankedTensorType type) {
  auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
                                   input_tensor.TotalBytes());
  return mlir::DenseElementsAttr::getFromRawBuffer(
      type, buffer,
      /*isSplatBuffer=*/type.getNumElements() == 1);
}

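// Converts an fp16 (DT_HALF) TensorFlow tensor to an MLIR elements attribute
// by reinterpreting the tensor's raw byte buffer.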
ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
  auto buffer = llvm::makeArrayRef(static_cast<char*>(tensor.data()),
                                   tensor.TotalBytes());
  return mlir::DenseElementsAttr::getFromRawBuffer(
      type, buffer,
      /*isSplatBuffer=*/type.getNumElements() == 1);
}

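// Converts a string TensorFlow tensor to an MLIR dense string elements
// attribute.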
StatusOr<ElementsAttr> ConvertStringTensor(const Tensor& input_tensor,
                                           ShapedType type) {
  // Extract to a vector of StringRefs for converting.
  auto arr = input_tensor.flat<tstring>();
  std::vector<mlir::StringRef> string_refs;
  string_refs.reserve(arr.size());
  for (int i = 0; i < arr.size(); i++) {
    const auto& val = arr(i);
    string_refs.push_back({val.data(), val.size()});
  }

  return DenseStringElementsAttr::get(type, string_refs);
}

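// Converts a TensorFlow Tensor to an MLIR ElementsAttr, dispatching on the
// tensor's dtype. Illustrative use (assuming `tensor` and `builder` are a
// Tensor and an mlir::Builder for the target context, respectively):
//   TF_ASSIGN_OR_RETURN(mlir::ElementsAttr attr,
//                       ConvertTensor(tensor, &builder));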
StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
                                     Builder* builder) {
  const auto& input_dtype = input_tensor.dtype();
  const auto& input_shape = input_tensor.shape();
  Type elt_type;
  TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
  SmallVector<int64_t, 4> shape;
  ConvertToMlirShape(input_shape, &shape);
  auto type = RankedTensorType::get(shape, elt_type);

#define CONVERT_FLAT(DTYPE, CTYPE) \
  case DTYPE:                      \
    return ConvertFlatTensor<CTYPE>(input_tensor, type);

  // TODO(fengliuai): customize the conversions for quantized and string types.
  switch (input_dtype) {
    CONVERT_FLAT(DT_BOOL, bool)
    CONVERT_FLAT(DT_FLOAT, float)
    CONVERT_FLAT(DT_DOUBLE, double)
    CONVERT_FLAT(DT_INT8, int8)
    CONVERT_FLAT(DT_INT16, int16)
    CONVERT_FLAT(DT_INT32, int32)
    CONVERT_FLAT(DT_INT64, int64)
    CONVERT_FLAT(DT_UINT8, uint8)
    CONVERT_FLAT(DT_UINT16, uint16)
    CONVERT_FLAT(DT_UINT32, uint32)
    CONVERT_FLAT(DT_UINT64, uint64)
    CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
    CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)

    // BFLOAT16 and HALF are special cases; they are converted directly from
    // the tensor's raw 16-bit storage buffer.
    case DT_BFLOAT16:
      return ConvertBf16Tensor(input_tensor, type);
    case DT_HALF:
      return ConvertHalfTensor(input_tensor, type);

    case DT_STRING:
      return ConvertStringTensor(input_tensor, type);

    default:
      // TODO(shpeisman): restructure code to reuse dialect pointer across
      // calls.
      auto* dialect = builder->getContext()->getLoadedDialect("tf");
      return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
  }

#undef CONVERT_FLAT
}

// Returns the number of elements present in this TensorProto, or -1 if that
// could not be determined. This might be less than the number of elements the
// proto's shape indicates, if we're storing a splat tensor.
int NumberOfMaterializedElements(const TensorProto& tensor) {
  if (!tensor.tensor_content().empty()) return -1;
  // We don't know which element type this protocol buffer is storing, and the
  // metaprogramming facilities for TensorProto are too limited to check the
  // element count without knowing this, so we need to manually dispatch to
  // each possible member of TensorProto, depending on its dtype.
#define MATCH(DTYPE, FIELD) \
  case DTYPE:               \
    return tensor.FIELD##_val().size()

  switch (tensor.dtype()) {
    MATCH(DT_FLOAT, float);
    MATCH(DT_DOUBLE, double);
    MATCH(DT_INT8, int);
    MATCH(DT_UINT8, int);
    MATCH(DT_INT16, int);
    MATCH(DT_UINT16, int);
    MATCH(DT_INT32, int);
    MATCH(DT_UINT32, uint32);
    MATCH(DT_INT64, int64);
    MATCH(DT_UINT64, uint64);
    MATCH(DT_BOOL, bool);
    MATCH(DT_HALF, half);
    MATCH(DT_BFLOAT16, half);
    MATCH(DT_STRING, string);

    // TODO(b/188995810): DenseElementsAttr::get doesn't support complex
    // Attributes being passed, so we bail out for now. This should just be
    //   MATCH(DT_COMPLEX64, scomplex) / 2;
    //   MATCH(DT_COMPLEX128, dcomplex) / 2;
    // when DenseElementsAttr is updated.
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    default:
      return -1;
  }
}

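// Converts a TensorProto to an MLIR ElementsAttr, materializing splats as
// mlir::SplatElementsAttr when the proto stores a single repeated value.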
StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
                                          Builder* builder) {
  // If there is only one actual element in the proto, but its shape would
  // indicate there are more values, then this is representing a splat tensor.
  // We can create an MLIR Attribute more efficiently in this case.
  TensorShape input_tensor_shape(input_tensor.tensor_shape());
  if (NumberOfMaterializedElements(input_tensor) == 1 &&
      input_tensor_shape.num_elements() > 1) {
    // We first convert this TensorProto to one of shape [1]. We then create an
    // Attribute for that proto, and finally splat the Attribute.

    TensorProto tensor_copy = input_tensor;
    auto* shape = tensor_copy.mutable_tensor_shape();
    shape->clear_dim();
    shape->add_dim()->set_size(1);

    TF_ASSIGN_OR_RETURN(ElementsAttr single_attr,
                        ConvertTensorProto(tensor_copy, builder));

    std::vector<int64_t> original_dimensions;
    for (auto dim : input_tensor_shape) original_dimensions.push_back(dim.size);
    return mlir::SplatElementsAttr::get(
        single_attr.getType().clone(original_dimensions),
        single_attr.getValue({0}));
  }

  Tensor t;
  if (!t.FromProto(input_tensor))
    return InvalidArgument("Failed to parse input_tensor.");
  return ConvertTensor(t, builder);
}

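// Converts an MLIR shape, expressed as an array of int64_t dimension sizes,
// to a TensorFlow TensorShapeProto.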
void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
                               TensorShapeProto* output_shape) {
  for (auto d : shape) {
    output_shape->add_dim()->set_size(d);
  }
}

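// Converts an MLIR type to a TensorFlow PartialTensorShape: unranked tensor
// types map to an unknown shape, ranked tensor types keep their dimensions,
// and all other types are treated as scalars.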
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
  if (type.isa<mlir::UnrankedTensorType>()) {
    // An empty PartialTensorShape indicates an unranked tensor.
    return PartialTensorShape();
  }

  if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
    TensorShapeProto tensor_shape_proto;
    ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
    return PartialTensorShape(tensor_shape_proto);
  }

  // If the type is not a RankedTensor or UnrankedTensor, it must be a scalar.
  // An empty TensorShape indicates a scalar.
  return TensorShape();
}

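// Converts an MLIR type to a TensorFlow dialect ShapeAttr, following the same
// rules as ConvertTypeToTensorShape above.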
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
  if (type.isa<mlir::UnrankedTensorType>()) {
    return mlir::TF::ShapeAttr::get(type.getContext(), llvm::None);
  }

  if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
    return mlir::TF::ShapeAttr::get(type.getContext(), tensor_type.getShape());
  }

  // If the type is not a RankedTensor or UnrankedTensor, it must be a scalar.
  // An empty TensorShape indicates a scalar.
  return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef<int64_t>());
}

// Converts the tensor shape proto into an MLIR shape attribute.
StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
                                                  mlir::MLIRContext* context) {
  if (shape.unknown_rank())
    return mlir::TF::ShapeAttr::get(context, llvm::None);

  llvm::SmallVector<int64_t, 4> dims;
  dims.reserve(shape.dim().size());
  for (const auto& dim : shape.dim()) {
    dims.push_back(dim.size());
  }
  return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
}

// Converts an MLIR dense string elements attribute to a TensorFlow tensor
// proto.
void ConvertStringElementsAttr(
    const DenseStringElementsAttr attr,
    protobuf::RepeatedPtrField<std::string>* output) {
  for (const auto& val : attr.getRawStringData())
    output->Add({val.data(), val.size()});
}

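// Converts an MLIR elements attribute of complex values to a TensorFlow
// tensor proto field, storing each value as an interleaved (real, imaginary)
// pair.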
template <typename T>
void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
                                protobuf::RepeatedField<T>* output) {
  for (const auto& val : attr.getValues<std::complex<T>>()) {
    output->Add(val.real());
    output->Add(val.imag());
  }
}

// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
                                 TensorProto* output_tensor) {
  if (attr.isa<OpaqueElementsAttr>()) {
    auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
    absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
    return mangling_util::DemangleTensor(tensor_view, output_tensor);
  }
  return InvalidArgument("Unexpected elements attribute type from MLIR.");
}

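// Converts an MLIR elements attribute and adds it to the specified repeated
// field; a splat of zeros is encoded as an empty field.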
template <typename T>
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
                         protobuf::RepeatedField<T>* output) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<T>() != T(0)) output->Add(attr.getSplatValue<T>());
  } else {
    output->Reserve(attr.getNumElements());
    for (auto value : attr.getValues<T>()) output->AddAlreadyReserved(value);
  }
}

// Converts an MLIR elements attribute containing floating point values and
// adds it to the specified repeated field.
template <typename T, typename Cord>
void ConvertFloatElementsAttr(const mlir::DenseElementsAttr attr,
                              protobuf::RepeatedField<T>* output,
                              Cord* tensor_content) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<T>() != T(0)) output->Add(attr.getSplatValue<T>());
  } else {
    port::CopyFromArray(tensor_content, attr.getRawData().data(),
                        attr.getRawData().size());
  }
}

// Converts an MLIR elements attribute containing half values and adds it to
// the specified repeated field.
void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,
                             protobuf::RepeatedField<int>* output) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<Eigen::half>() != Eigen::half(0))
      output->Add(
          Eigen::numext::bit_cast<uint16_t>(attr.getSplatValue<Eigen::half>()));
  } else {
    output->Reserve(attr.getNumElements());
    for (const Eigen::half value : attr.getValues<Eigen::half>())
      output->AddAlreadyReserved(Eigen::numext::bit_cast<uint16_t>(value));
  }
}

// Converts an MLIR elements attribute containing signed int values and adds
// it to the specified repeated field.
template <typename T, typename U = T, typename Cord>
void ConvertIntElementsAttr(const mlir::DenseElementsAttr attr,
                            protobuf::RepeatedField<T>* output,
                            Cord* tensor_content) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<U>() != U(0)) output->Add(attr.getSplatValue<U>());
  } else {
    port::CopyFromArray(tensor_content, attr.getRawData().data(),
                        attr.getRawData().size());
  }
}

// Converts an MLIR elements attribute containing unsigned int values and adds
// it to the specified repeated field.
template <typename T, typename U = T, typename Cord>
void ConvertUIntElementsAttr(const mlir::DenseElementsAttr attr,
                             protobuf::RepeatedField<T>* output,
                             Cord* tensor_content) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<U>() != U(0)) output->Add(attr.getSplatValue<U>());
  } else {
    port::CopyFromArray(tensor_content, attr.getRawData().data(),
                        attr.getRawData().size());
  }
}

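// Converts an MLIR elements attribute containing bfloat16 values and adds it
// to the specified repeated field, bit-casting each value to uint16.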
void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,
                                 protobuf::RepeatedField<int>* output) {
  if (attr.isSplat()) {
    if (attr.getSplatValue<bfloat16>() != bfloat16(0))
      output->Add(
          Eigen::numext::bit_cast<uint16_t>(attr.getSplatValue<bfloat16>()));
  } else {
    output->Reserve(attr.getNumElements());
    for (const bfloat16 value : attr.getValues<bfloat16>())
      output->AddAlreadyReserved(Eigen::numext::bit_cast<uint16_t>(value));
  }
}

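// Converts an MLIR elements attribute to a TensorFlow tensor proto,
// dispatching on the corresponding TensorFlow dtype.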
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
  auto type = attr.getType();
  auto shape = type.getShape();
  DataType output_dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
  output->set_dtype(output_dtype);
  ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());

  if (attr.isa<OpaqueElementsAttr>())
    return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);

  auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
  if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");

  switch (output_dtype) {
    case DT_BOOL:
      ConvertElementsAttr(dense_attr, output->mutable_bool_val());
      break;
    case DT_BFLOAT16:
      ConvertBfloat16ElementsAttr(dense_attr, output->mutable_half_val());
      break;
    case DT_COMPLEX64:
      ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
      break;
    case DT_COMPLEX128:
      ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
      break;
    case DT_DOUBLE:
      ConvertFloatElementsAttr(dense_attr, output->mutable_double_val(),
                               output->mutable_tensor_content());
      break;
    case DT_HALF:
      ConvertHalfElementsAttr(dense_attr, output->mutable_half_val());
      break;
    case DT_FLOAT:
      ConvertFloatElementsAttr(dense_attr, output->mutable_float_val(),
                               output->mutable_tensor_content());
      break;
    case DT_QUINT8:
    case DT_INT8:
      ConvertUIntElementsAttr<int, int8_t>(dense_attr,
                                           output->mutable_int_val(),
                                           output->mutable_tensor_content());
      break;
    case DT_QUINT16:
    case DT_INT16:
      ConvertIntElementsAttr<int, int16_t>(dense_attr,
                                           output->mutable_int_val(),
                                           output->mutable_tensor_content());
      break;
    case DT_INT32:
      ConvertIntElementsAttr(dense_attr, output->mutable_int_val(),
                             output->mutable_tensor_content());
      break;
    case DT_INT64:
      ConvertIntElementsAttr(dense_attr, output->mutable_int64_val(),
                             output->mutable_tensor_content());
      break;
    case DT_STRING:
      ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
                                output->mutable_string_val());
      break;
    case DT_UINT8:
      ConvertUIntElementsAttr<int, uint8_t>(dense_attr,
                                            output->mutable_int_val(),
                                            output->mutable_tensor_content());
      break;
    case DT_UINT16:
      ConvertUIntElementsAttr<int, uint16_t>(dense_attr,
                                             output->mutable_int_val(),
                                             output->mutable_tensor_content());
      break;
    case DT_UINT32:
      ConvertUIntElementsAttr(dense_attr, output->mutable_uint32_val(),
                              output->mutable_tensor_content());
      break;
    case DT_UINT64:
      ConvertUIntElementsAttr(dense_attr, output->mutable_uint64_val(),
                              output->mutable_tensor_content());
      break;
    default:
      return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
                                                DataTypeString(output_dtype)));
  }
  return Status::OK();
}

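// Converts an MLIR elements attribute to a TensorFlow Tensor. Illustrative
// use (assuming `attr` is any supported mlir::ElementsAttr):
//   Tensor out;
//   TF_RETURN_IF_ERROR(ConvertToTensor(attr, &out));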
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
  TensorProto tensor_proto;
  TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
  if (!output_tensor->FromProto(tensor_proto)) {
    return InvalidArgument("Couldn't convert tensor proto to tensor.");
  }
  return Status::OK();
}

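// Decodes an opaque elements attribute holding a mangled TensorProto back
// into a dense MLIR elements attribute.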
StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
    const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
  // TODO(antiagainst): The following logic, albeit simple, involves copying the
  // tensor content multiple times, which is bad. Figure out a better way to
  // achieve the purpose.
  Tensor tensor;
  TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
  return ConvertTensor(tensor, &builder);
}

}  // namespace tensorflow