• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
17 
18 #include <limits>
19 
20 #include "absl/base/casts.h"
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Types.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor.pb.h"
37 #include "tensorflow/core/framework/tensor_shape.pb.h"
38 #include "tensorflow/core/framework/types.pb.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/bfloat16.h"
41 #include "tensorflow/core/platform/errors.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/tstring.h"
44 #include "tensorflow/stream_executor/lib/statusor.h"
45 
46 namespace tensorflow {
47 
48 using llvm::ArrayRef;
49 using llvm::SmallVector;
50 using mlir::Builder;
51 using mlir::DenseFPElementsAttr;
52 using mlir::DenseIntElementsAttr;
53 using mlir::DenseStringElementsAttr;
54 using mlir::ElementsAttr;
55 using mlir::OpaqueElementsAttr;
56 using mlir::RankedTensorType;
57 using mlir::ShapedType;
58 using mlir::Type;
59 using tensorflow::errors::InvalidArgument;
60 
ConvertToProto(const Tensor & input_tensor,bool use_tensor_content=true)61 static TensorProto ConvertToProto(const Tensor& input_tensor,
62                                   bool use_tensor_content = true) {
63   TensorProto tensor_proto;
64   // Using tensor content (mostly*) reduces serialization overhead during RPC
65   // calls, but is less human reader friendly. People reading protobufs are less
66   // frequent than serialization, so default to using tensor content
67   // representation.
68   // * For scalars and short strings it may be marginally worse and a more
69   //   intelligent decision could be made by caller.
70   if (use_tensor_content)
71     input_tensor.AsProtoTensorContent(&tensor_proto);
72   else
73     input_tensor.AsProtoField(&tensor_proto);
74   return tensor_proto;
75 }
76 
MangleTensor(const Tensor & tensor)77 static std::string MangleTensor(const Tensor& tensor) {
78   return mangling_util::MangleTensor(ConvertToProto(tensor));
79 }
80 
81 // Converts a TensorFlow tensor into an MLIR elements attribute.
82 template <typename T>
ConvertFlatTensor(const Tensor & input_tensor,ShapedType type)83 StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
84                                          ShapedType type) {
85   auto arr = input_tensor.flat<T>();
86   return mlir::DenseElementsAttr::get(
87       type, llvm::makeArrayRef(arr.data(), arr.size()));
88 }
89 
ConvertBf16Tensor(const Tensor & input_tensor,RankedTensorType type)90 ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
91                                RankedTensorType type) {
92   auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
93                                    input_tensor.TotalBytes());
94   return mlir::DenseElementsAttr::getFromRawBuffer(
95       type, buffer,
96       /*isSplatBuffer=*/type.getNumElements() == 1);
97 }
98 
ConvertHalfTensor(const Tensor & tensor,RankedTensorType type)99 ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
100   auto buffer = llvm::makeArrayRef(static_cast<char*>(tensor.data()),
101                                    tensor.TotalBytes());
102   return mlir::DenseElementsAttr::getFromRawBuffer(
103       type, buffer,
104       /*isSplatBuffer=*/type.getNumElements() == 1);
105 }
106 
ConvertStringTensor(const Tensor & input_tensor,ShapedType type)107 StatusOr<ElementsAttr> ConvertStringTensor(const Tensor& input_tensor,
108                                            ShapedType type) {
109   // Extract to a vector of StringRefs for converting.
110   auto arr = input_tensor.flat<tstring>();
111   std::vector<mlir::StringRef> string_refs;
112   string_refs.reserve(arr.size());
113   for (int i = 0; i < arr.size(); i++) {
114     const auto& val = arr(i);
115     string_refs.push_back({val.data(), val.size()});
116   }
117 
118   return DenseStringElementsAttr::get(type, string_refs);
119 }
120 
ConvertTensor(const Tensor & input_tensor,Builder * builder)121 StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
122                                      Builder* builder) {
123   const auto& input_dtype = input_tensor.dtype();
124   const auto& input_shape = input_tensor.shape();
125   Type elt_type;
126   TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
127   SmallVector<int64_t, 4> shape;
128   ConvertToMlirShape(input_shape, &shape);
129   auto type = RankedTensorType::get(shape, elt_type);
130 
131 #define CONVERT_FLAT(DTYPE, CTYPE) \
132   case DTYPE:                      \
133     return ConvertFlatTensor<CTYPE>(input_tensor, type);
134 
135   // TODO(fengliuai): customize the conversions for quantized and string types.
136   switch (input_dtype) {
137     CONVERT_FLAT(DT_BOOL, bool)
138     CONVERT_FLAT(DT_FLOAT, float)
139     CONVERT_FLAT(DT_DOUBLE, double)
140     CONVERT_FLAT(DT_INT8, int8)
141     CONVERT_FLAT(DT_INT16, int16)
142     CONVERT_FLAT(DT_INT32, int32)
143     CONVERT_FLAT(DT_INT64, int64)
144     CONVERT_FLAT(DT_UINT8, uint8)
145     CONVERT_FLAT(DT_UINT16, uint16)
146     CONVERT_FLAT(DT_UINT32, uint32)
147     CONVERT_FLAT(DT_UINT64, uint64)
148     CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
149     CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
150 
151     // BFLOAT16 is a special case that it needs to be cast to double type to
152     // match its storage type.
153     case DT_BFLOAT16:
154       return ConvertBf16Tensor(input_tensor, type);
155     case DT_HALF:
156       return ConvertHalfTensor(input_tensor, type);
157 
158     case DT_STRING:
159       return ConvertStringTensor(input_tensor, type);
160 
161     default:
162       // TODO(shpeisman): restructure code to reuse dialect pointer across
163       // calls.
164       auto* dialect = builder->getContext()->getLoadedDialect("tf");
165       return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
166   }
167 
168 #undef CONVERT_FLAT
169 }
170 
ConvertTensorProto(const TensorProto & input_tensor,Builder * builder)171 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
172                                           Builder* builder) {
173   Tensor t;
174   if (!t.FromProto(input_tensor))
175     return InvalidArgument("Failed to parse input_tensor.");
176   return ConvertTensor(t, builder);
177 }
178 
ConvertToTensorShapeProto(ArrayRef<int64_t> shape,TensorShapeProto * output_shape)179 void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
180                                TensorShapeProto* output_shape) {
181   for (auto d : shape) {
182     output_shape->add_dim()->set_size(d);
183   }
184 }
185 
ConvertTypeToTensorShape(const mlir::Type & type)186 PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
187   if (type.isa<mlir::UnrankedTensorType>()) {
188     // An empty PartialTensorShape indicates an unranked tensor.
189     return PartialTensorShape();
190   }
191 
192   if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
193     TensorShapeProto tensor_shape_proto;
194     ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
195     return PartialTensorShape(tensor_shape_proto);
196   }
197 
198   // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
199   // Empty TensorShape indicates a scalar.
200   return TensorShape();
201 }
202 
// Maps an MLIR type to a TF::ShapeAttr: unranked tensors become an unknown
// shape (llvm::None), ranked tensors carry their dimensions, and anything else
// is treated as a scalar (empty shape).
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
  if (type.isa<mlir::UnrankedTensorType>()) {
    return mlir::TF::ShapeAttr::get(type.getContext(), llvm::None);
  }

  auto ranked = type.dyn_cast<mlir::RankedTensorType>();
  if (ranked) {
    return mlir::TF::ShapeAttr::get(type.getContext(), ranked.getShape());
  }

  // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
  // Empty TensorShape indicates a scalar.
  return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef<int64_t>());
}
216 
217 // Converts the tensor shape proto into an MLIR shape attribute.
ConvertTensorShapeProto(const TensorShapeProto & shape,mlir::MLIRContext * context)218 StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
219                                                   mlir::MLIRContext* context) {
220   if (shape.unknown_rank())
221     return mlir::TF::ShapeAttr::get(context, llvm::None);
222 
223   llvm::SmallVector<int64_t, 4> dims;
224   dims.reserve(shape.dim().size());
225   for (const auto& dim : shape.dim()) {
226     dims.push_back(dim.size());
227   }
228   return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
229 }
230 
231 // Converts an MLIR dense string elements attribute to a TensorFlow tensor
232 // proto.
ConvertStringElementsAttr(const DenseStringElementsAttr attr,protobuf::RepeatedPtrField<std::string> * output)233 void ConvertStringElementsAttr(
234     const DenseStringElementsAttr attr,
235     protobuf::RepeatedPtrField<std::string>* output) {
236   for (const auto& val : attr.getRawStringData())
237     output->Add({val.data(), val.size()});
238 }
239 
240 template <typename T>
ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)241 void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
242                                 protobuf::RepeatedField<T>* output) {
243   for (const auto& val : attr.getValues<std::complex<T>>()) {
244     output->Add(val.real());
245     output->Add(val.imag());
246   }
247 }
248 
249 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
ConvertOpaqueElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)250 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
251                                  TensorProto* output_tensor) {
252   if (attr.isa<OpaqueElementsAttr>()) {
253     auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
254     absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
255     return mangling_util::DemangleTensor(tensor_view, output_tensor);
256   }
257   return InvalidArgument("Unexpected elements attribute type from MLIR.");
258 }
259 
260 // Converts an MLIR elements attribute and adds it to specified repeated field.
261 template <typename T>
ConvertElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)262 void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
263                          protobuf::RepeatedField<T>* output) {
264   if (attr.isSplat()) {
265     output->Add(attr.getSplatValue<T>());
266   } else {
267     output->Reserve(attr.getNumElements());
268     for (auto value : attr.getValues<T>()) output->AddAlreadyReserved(value);
269   }
270 }
271 
272 // Converts an MLIR elements attribute containing half values and adds it to
273 // specified repeated field.
ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)274 void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,
275                              protobuf::RepeatedField<int>* output) {
276   if (attr.isSplat()) {
277     output->Add(attr.getSplatValue<Eigen::half>().x);
278   } else {
279     output->Reserve(attr.getNumElements());
280     for (const Eigen::half value : attr.getValues<Eigen::half>())
281       output->AddAlreadyReserved(value.x);
282   }
283 }
284 
285 // Converts an MLIR elements attribute containing int values and adds it to
286 // specified repeated field.
ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,protobuf::RepeatedField<int> * output)287 void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
288                             protobuf::RepeatedField<int>* output) {
289   if (attr.isSplat()) {
290     output->Add((*attr.begin()).getSExtValue());
291   } else {
292     output->Reserve(attr.getNumElements());
293     for (const llvm::APInt val : attr)
294       output->AddAlreadyReserved(val.getSExtValue());
295   }
296 }
297 
ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)298 void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,
299                                  protobuf::RepeatedField<int>* output) {
300   if (attr.isSplat()) {
301     output->Add(attr.getSplatValue<bfloat16>().value);
302   } else {
303     output->Reserve(attr.getNumElements());
304     for (const bfloat16 value : attr.getValues<bfloat16>())
305       output->AddAlreadyReserved(value.value);
306   }
307 }
308 
ConvertToTensorProto(const ElementsAttr attr,TensorProto * output)309 Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
310   auto type = attr.getType();
311   auto shape = type.getShape();
312   DataType output_dtype;
313   TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
314   output->set_dtype(output_dtype);
315   ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
316 
317   if (attr.isa<OpaqueElementsAttr>())
318     return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
319 
320   auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
321   if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
322 
323   switch (output_dtype) {
324     case DT_FLOAT:
325       ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
326       break;
327     case DT_HALF:
328       ConvertHalfElementsAttr(dense_attr, output->mutable_half_val());
329       break;
330     case DT_DOUBLE:
331       ConvertElementsAttr(dense_attr, output->mutable_double_val());
332       break;
333     case DT_QUINT8:
334     case DT_UINT8:
335     case DT_INT8:
336     case DT_QUINT16:
337     case DT_UINT16:
338     case DT_INT16:
339     case DT_INT32:
340       ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
341                              output->mutable_int_val());
342       break;
343     case DT_UINT32:
344       ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
345       break;
346     case DT_UINT64:
347       ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
348       break;
349     case DT_INT64:
350       ConvertElementsAttr(dense_attr, output->mutable_int64_val());
351       break;
352     case DT_BOOL:
353       ConvertElementsAttr(dense_attr, output->mutable_bool_val());
354       break;
355     case DT_BFLOAT16:
356       ConvertBfloat16ElementsAttr(dense_attr, output->mutable_half_val());
357       break;
358     case DT_STRING:
359       ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
360                                 output->mutable_string_val());
361       break;
362     case DT_COMPLEX64:
363       ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
364       break;
365     case DT_COMPLEX128:
366       ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
367       break;
368     default:
369       return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
370                                                 DataTypeString(output_dtype)));
371   }
372   return Status::OK();
373 }
374 
ConvertToTensor(const mlir::ElementsAttr attr,Tensor * output_tensor)375 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
376   TensorProto tensor_proto;
377   TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
378   if (!output_tensor->FromProto(tensor_proto)) {
379     return InvalidArgument("Couldn't convert tensor proto to tensor.");
380   }
381   return Status::OK();
382 }
383 
DecodeOpaqueTensor(const mlir::OpaqueElementsAttr input_attr,mlir::Builder builder)384 StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
385     const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
386   // TODO(antiagainst): The following logic, albeit simple, involves copying the
387   // tensor content multiple times, which is bad. Figure out a better way to
388   // achieve the purpose.
389   Tensor tensor;
390   TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
391   return ConvertTensor(tensor, &builder);
392 }
393 
394 }  // namespace tensorflow
395