1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
17
18 #include <limits>
19
20 #include "absl/base/casts.h"
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Types.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/framework/tensor_shape.pb.h"
38 #include "tensorflow/core/framework/types.pb.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/bfloat16.h"
41 #include "tensorflow/core/platform/errors.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/tstring.h"
44 #include "tensorflow/stream_executor/lib/statusor.h"
45
46 namespace tensorflow {
47
48 using llvm::ArrayRef;
49 using llvm::SmallVector;
50 using mlir::Builder;
51 using mlir::DenseFPElementsAttr;
52 using mlir::DenseIntElementsAttr;
53 using mlir::DenseStringElementsAttr;
54 using mlir::ElementsAttr;
55 using mlir::OpaqueElementsAttr;
56 using mlir::RankedTensorType;
57 using mlir::ShapedType;
58 using mlir::Type;
59 using tensorflow::errors::InvalidArgument;
60
ConvertToProto(const Tensor & input_tensor,bool use_tensor_content=true)61 static TensorProto ConvertToProto(const Tensor& input_tensor,
62 bool use_tensor_content = true) {
63 TensorProto tensor_proto;
64 // Using tensor content (mostly*) reduces serialization overhead during RPC
65 // calls, but is less human reader friendly. People reading protobufs are less
66 // frequent than serialization, so default to using tensor content
67 // representation.
68 // * For scalars and short strings it may be marginally worse and a more
69 // intelligent decision could be made by caller.
70 if (use_tensor_content)
71 input_tensor.AsProtoTensorContent(&tensor_proto);
72 else
73 input_tensor.AsProtoField(&tensor_proto);
74 return tensor_proto;
75 }
76
MangleTensor(const Tensor & tensor)77 static std::string MangleTensor(const Tensor& tensor) {
78 return mangling_util::MangleTensor(ConvertToProto(tensor));
79 }
80
81 // Converts a TensorFlow tensor into an MLIR elements attribute.
82 template <typename T>
ConvertFlatTensor(const Tensor & input_tensor,ShapedType type)83 StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
84 ShapedType type) {
85 auto arr = input_tensor.flat<T>();
86 return mlir::DenseElementsAttr::get(
87 type, llvm::makeArrayRef(arr.data(), arr.size()));
88 }
89
ConvertBf16Tensor(const Tensor & input_tensor,RankedTensorType type)90 ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
91 RankedTensorType type) {
92 auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
93 input_tensor.TotalBytes());
94 return mlir::DenseElementsAttr::getFromRawBuffer(
95 type, buffer,
96 /*isSplatBuffer=*/type.getNumElements() == 1);
97 }
98
ConvertHalfTensor(const Tensor & tensor,RankedTensorType type)99 ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
100 auto buffer = llvm::makeArrayRef(static_cast<char*>(tensor.data()),
101 tensor.TotalBytes());
102 return mlir::DenseElementsAttr::getFromRawBuffer(
103 type, buffer,
104 /*isSplatBuffer=*/type.getNumElements() == 1);
105 }
106
ConvertStringTensor(const Tensor & input_tensor,ShapedType type)107 StatusOr<ElementsAttr> ConvertStringTensor(const Tensor& input_tensor,
108 ShapedType type) {
109 // Extract to a vector of StringRefs for converting.
110 auto arr = input_tensor.flat<tstring>();
111 std::vector<mlir::StringRef> string_refs;
112 string_refs.reserve(arr.size());
113 for (int i = 0; i < arr.size(); i++) {
114 const auto& val = arr(i);
115 string_refs.push_back({val.data(), val.size()});
116 }
117
118 return DenseStringElementsAttr::get(type, string_refs);
119 }
120
// Converts a TensorFlow tensor into an MLIR elements attribute.
//
// Maps the tensor's dtype to an MLIR element type and its shape to a
// RankedTensorType, then dispatches on dtype: flat numeric data via
// ConvertFlatTensor, raw-buffer paths for bfloat16/half, a dedicated string
// path, and an opaque attribute holding the mangled serialized proto for any
// other dtype.
StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
                                     Builder* builder) {
  const auto& input_dtype = input_tensor.dtype();
  const auto& input_shape = input_tensor.shape();
  Type elt_type;
  TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
  SmallVector<int64_t, 4> shape;
  ConvertToMlirShape(input_shape, &shape);
  auto type = RankedTensorType::get(shape, elt_type);

// Expands to a case that converts the tensor's flat data of C++ type CTYPE.
#define CONVERT_FLAT(DTYPE, CTYPE) \
  case DTYPE:                      \
    return ConvertFlatTensor<CTYPE>(input_tensor, type);

  // TODO(fengliuai): customize the conversions for quantized and string types.
  switch (input_dtype) {
    CONVERT_FLAT(DT_BOOL, bool)
    CONVERT_FLAT(DT_FLOAT, float)
    CONVERT_FLAT(DT_DOUBLE, double)
    CONVERT_FLAT(DT_INT8, int8)
    CONVERT_FLAT(DT_INT16, int16)
    CONVERT_FLAT(DT_INT32, int32)
    CONVERT_FLAT(DT_INT64, int64)
    CONVERT_FLAT(DT_UINT8, uint8)
    CONVERT_FLAT(DT_UINT16, uint16)
    CONVERT_FLAT(DT_UINT32, uint32)
    CONVERT_FLAT(DT_UINT64, uint64)
    CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
    CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)

    // BFLOAT16 is a special case that it needs to be cast to double type to
    // match its storage type.
    case DT_BFLOAT16:
      return ConvertBf16Tensor(input_tensor, type);
    case DT_HALF:
      return ConvertHalfTensor(input_tensor, type);

    case DT_STRING:
      return ConvertStringTensor(input_tensor, type);

    default:
      // TODO(shpeisman): restructure code to reuse dialect pointer across
      // calls.
      auto* dialect = builder->getContext()->getLoadedDialect("tf");
      return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
  }

#undef CONVERT_FLAT
}
170
ConvertTensorProto(const TensorProto & input_tensor,Builder * builder)171 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
172 Builder* builder) {
173 Tensor t;
174 if (!t.FromProto(input_tensor))
175 return InvalidArgument("Failed to parse input_tensor.");
176 return ConvertTensor(t, builder);
177 }
178
ConvertToTensorShapeProto(ArrayRef<int64_t> shape,TensorShapeProto * output_shape)179 void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
180 TensorShapeProto* output_shape) {
181 for (auto d : shape) {
182 output_shape->add_dim()->set_size(d);
183 }
184 }
185
ConvertTypeToTensorShape(const mlir::Type & type)186 PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
187 if (type.isa<mlir::UnrankedTensorType>()) {
188 // An empty PartialTensorShape indicates an unranked tensor.
189 return PartialTensorShape();
190 }
191
192 if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
193 TensorShapeProto tensor_shape_proto;
194 ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
195 return PartialTensorShape(tensor_shape_proto);
196 }
197
198 // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
199 // Empty TensorShape indicates a scalar.
200 return TensorShape();
201 }
202
// Converts an MLIR type to a TF shape attribute: unranked tensors map to an
// unknown-rank shape, ranked tensors keep their dims, and anything else is
// treated as a scalar (empty shape).
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
  auto* context = type.getContext();

  if (type.isa<mlir::UnrankedTensorType>())
    return mlir::TF::ShapeAttr::get(context, llvm::None);

  if (auto ranked_type = type.dyn_cast<mlir::RankedTensorType>())
    return mlir::TF::ShapeAttr::get(context, ranked_type.getShape());

  // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
  // Empty TensorShape indicates a scalar.
  return mlir::TF::ShapeAttr::get(context, ArrayRef<int64_t>());
}
216
217 // Converts the tensor shape proto into an MLIR shape attribute.
ConvertTensorShapeProto(const TensorShapeProto & shape,mlir::MLIRContext * context)218 StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
219 mlir::MLIRContext* context) {
220 if (shape.unknown_rank())
221 return mlir::TF::ShapeAttr::get(context, llvm::None);
222
223 llvm::SmallVector<int64_t, 4> dims;
224 dims.reserve(shape.dim().size());
225 for (const auto& dim : shape.dim()) {
226 dims.push_back(dim.size());
227 }
228 return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
229 }
230
231 // Converts an MLIR dense string elements attribute to a TensorFlow tensor
232 // proto.
ConvertStringElementsAttr(const DenseStringElementsAttr attr,protobuf::RepeatedPtrField<std::string> * output)233 void ConvertStringElementsAttr(
234 const DenseStringElementsAttr attr,
235 protobuf::RepeatedPtrField<std::string>* output) {
236 for (const auto& val : attr.getRawStringData())
237 output->Add({val.data(), val.size()});
238 }
239
240 template <typename T>
ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)241 void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
242 protobuf::RepeatedField<T>* output) {
243 for (const auto& val : attr.getValues<std::complex<T>>()) {
244 output->Add(val.real());
245 output->Add(val.imag());
246 }
247 }
248
249 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
ConvertOpaqueElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)250 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
251 TensorProto* output_tensor) {
252 if (attr.isa<OpaqueElementsAttr>()) {
253 auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
254 absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
255 return mangling_util::DemangleTensor(tensor_view, output_tensor);
256 }
257 return InvalidArgument("Unexpected elements attribute type from MLIR.");
258 }
259
260 // Converts an MLIR elements attribute and adds it to specified repeated field.
261 template <typename T>
ConvertElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)262 void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
263 protobuf::RepeatedField<T>* output) {
264 if (attr.isSplat()) {
265 output->Add(attr.getSplatValue<T>());
266 } else {
267 output->Reserve(attr.getNumElements());
268 for (auto value : attr.getValues<T>()) output->AddAlreadyReserved(value);
269 }
270 }
271
272 // Converts an MLIR elements attribute containing half values and adds it to
273 // specified repeated field.
ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)274 void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,
275 protobuf::RepeatedField<int>* output) {
276 if (attr.isSplat()) {
277 output->Add(attr.getSplatValue<Eigen::half>().x);
278 } else {
279 output->Reserve(attr.getNumElements());
280 for (const Eigen::half value : attr.getValues<Eigen::half>())
281 output->AddAlreadyReserved(value.x);
282 }
283 }
284
285 // Converts an MLIR elements attribute containing int values and adds it to
286 // specified repeated field.
ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,protobuf::RepeatedField<int> * output)287 void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
288 protobuf::RepeatedField<int>* output) {
289 if (attr.isSplat()) {
290 output->Add((*attr.begin()).getSExtValue());
291 } else {
292 output->Reserve(attr.getNumElements());
293 for (const llvm::APInt val : attr)
294 output->AddAlreadyReserved(val.getSExtValue());
295 }
296 }
297
ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)298 void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,
299 protobuf::RepeatedField<int>* output) {
300 if (attr.isSplat()) {
301 output->Add(attr.getSplatValue<bfloat16>().value);
302 } else {
303 output->Reserve(attr.getNumElements());
304 for (const bfloat16 value : attr.getValues<bfloat16>())
305 output->AddAlreadyReserved(value.value);
306 }
307 }
308
// Converts an MLIR elements attribute into a TensorFlow tensor proto.
//
// Fills in the proto's dtype and shape from the attribute's type, then
// dispatches on dtype to the matching typed value field. Opaque attributes
// are demangled directly; all others must be dense. Returns Unimplemented for
// dtypes without a conversion path.
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
  auto type = attr.getType();
  auto shape = type.getShape();
  DataType output_dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
  output->set_dtype(output_dtype);
  ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());

  // Opaque attributes already hold a mangled serialized proto.
  if (attr.isa<OpaqueElementsAttr>())
    return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);

  auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
  if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");

  switch (output_dtype) {
    case DT_FLOAT:
      ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
      break;
    case DT_HALF:
      // Half values are stored in half_val as raw 16-bit patterns.
      ConvertHalfElementsAttr(dense_attr, output->mutable_half_val());
      break;
    case DT_DOUBLE:
      ConvertElementsAttr(dense_attr, output->mutable_double_val());
      break;
    // All integer types of width <= 32 bits (including quantized) share the
    // proto's int_val field.
    case DT_QUINT8:
    case DT_UINT8:
    case DT_INT8:
    case DT_QUINT16:
    case DT_UINT16:
    case DT_INT16:
    case DT_INT32:
      ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
                             output->mutable_int_val());
      break;
    case DT_UINT32:
      ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
      break;
    case DT_UINT64:
      ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
      break;
    case DT_INT64:
      ConvertElementsAttr(dense_attr, output->mutable_int64_val());
      break;
    case DT_BOOL:
      ConvertElementsAttr(dense_attr, output->mutable_bool_val());
      break;
    case DT_BFLOAT16:
      // NOTE: bfloat16 shares half_val with DT_HALF in the proto.
      ConvertBfloat16ElementsAttr(dense_attr, output->mutable_half_val());
      break;
    case DT_STRING:
      ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
                                output->mutable_string_val());
      break;
    case DT_COMPLEX64:
      ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
      break;
    case DT_COMPLEX128:
      ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
      break;
    default:
      return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
                                                DataTypeString(output_dtype)));
  }
  return Status::OK();
}
374
ConvertToTensor(const mlir::ElementsAttr attr,Tensor * output_tensor)375 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
376 TensorProto tensor_proto;
377 TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
378 if (!output_tensor->FromProto(tensor_proto)) {
379 return InvalidArgument("Couldn't convert tensor proto to tensor.");
380 }
381 return Status::OK();
382 }
383
DecodeOpaqueTensor(const mlir::OpaqueElementsAttr input_attr,mlir::Builder builder)384 StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
385 const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
386 // TODO(antiagainst): The following logic, albeit simple, involves copying the
387 // tensor content multiple times, which is bad. Figure out a better way to
388 // achieve the purpose.
389 Tensor tensor;
390 TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
391 return ConvertTensor(tensor, &builder);
392 }
393
394 } // namespace tensorflow
395