• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
17 
18 #include <limits>
19 
20 #include "absl/base/casts.h"
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/IR/Attributes.h"  // TF:llvm-project
28 #include "mlir/IR/Builders.h"  // TF:llvm-project
29 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
30 #include "mlir/IR/Types.h"  // TF:llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/types.pb.h"
38 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/protobuf.h"
41 #include "tensorflow/stream_executor/lib/statusor.h"
42 
43 namespace tensorflow {
44 
45 using llvm::ArrayRef;
46 using llvm::SmallVector;
47 using mlir::Builder;
48 using mlir::DenseFPElementsAttr;
49 using mlir::DenseIntElementsAttr;
50 using mlir::ElementsAttr;
51 using mlir::OpaqueElementsAttr;
52 using mlir::RankedTensorType;
53 using mlir::ShapedType;
54 using mlir::Type;
55 using tensorflow::errors::InvalidArgument;
56 
ConvertToProto(const Tensor & input_tensor,bool use_tensor_content=true)57 static TensorProto ConvertToProto(const Tensor& input_tensor,
58                                   bool use_tensor_content = true) {
59   TensorProto tensor_proto;
60   // Using tensor content (mostly*) reduces serialization overhead during RPC
61   // calls, but is less human reader friendly. People reading protobufs are less
62   // frequent than serialization, so default to using tensor content
63   // representation.
64   // * For scalars and short strings it may be marginally worse and a more
65   //   intelligent decision could be made by caller.
66   if (use_tensor_content)
67     input_tensor.AsProtoTensorContent(&tensor_proto);
68   else
69     input_tensor.AsProtoField(&tensor_proto);
70   return tensor_proto;
71 }
72 
MangleTensor(const Tensor & tensor)73 static std::string MangleTensor(const Tensor& tensor) {
74   return mangling_util::MangleTensor(ConvertToProto(tensor));
75 }
76 
77 // Converts a TensorFlow tensor into an MLIR elements attribute.
78 template <typename T>
ConvertFlatTensor(const Tensor & input_tensor,ShapedType type)79 StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
80                                          ShapedType type) {
81   auto arr = input_tensor.flat<T>();
82   return mlir::DenseElementsAttr::get(
83       type, llvm::makeArrayRef(arr.data(), arr.size()));
84 }
85 
ConvertBF16Tensor(const Tensor & input_tensor,ShapedType type)86 StatusOr<ElementsAttr> ConvertBF16Tensor(const Tensor& input_tensor,
87                                          ShapedType type) {
88   auto flat = input_tensor.flat<bfloat16>();
89 
90   llvm::SmallVector<double, 4> flat_double;
91   flat_double.reserve(flat.size());
92   for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) {
93     flat_double.push_back(static_cast<double>(v));
94   }
95   return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(flat_double));
96 }
97 
ConvertTensor(const Tensor & input_tensor,Builder * builder)98 StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
99                                      Builder* builder) {
100   const auto& input_dtype = input_tensor.dtype();
101   const auto& input_shape = input_tensor.shape();
102   Type elt_type;
103   TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
104   SmallVector<int64_t, 4> shape;
105   ConvertToMlirShape(input_shape, &shape);
106   auto type = RankedTensorType::get(shape, elt_type);
107 
108 #define CONVERT_FLAT(DTYPE, CTYPE) \
109   case DTYPE:                      \
110     return ConvertFlatTensor<CTYPE>(input_tensor, type);
111 
112   // TODO(fengliuai): customize the conversions for more types.
113   switch (input_dtype) {
114     CONVERT_FLAT(DT_BOOL, bool)
115     CONVERT_FLAT(DT_FLOAT, float)
116     CONVERT_FLAT(DT_DOUBLE, double)
117     CONVERT_FLAT(DT_INT32, int32)
118     CONVERT_FLAT(DT_INT64, int64)
119 
120     // BFLOAT16 is a special case that it needs to be cast to double type to
121     // match its storage type.
122     case DT_BFLOAT16:
123       return ConvertBF16Tensor(input_tensor, type);
124 
125     default:
126       // TODO(shpeisman): restructure code to reuse dialect pointer across
127       // calls.
128       auto* dialect = builder->getContext()->getRegisteredDialect("tf");
129       return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
130   }
131 
132 #undef CONVERT_FLAT
133 }
134 
ConvertTensorProto(const TensorProto & input_tensor,Builder * builder)135 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
136                                           Builder* builder) {
137   Tensor t;
138   if (!t.FromProto(input_tensor))
139     return InvalidArgument("Failed to parse input_tensor.");
140   return ConvertTensor(t, builder);
141 }
142 
ConvertToTensorShapeProto(ArrayRef<int64_t> shape,TensorShapeProto * output_shape)143 void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
144                                TensorShapeProto* output_shape) {
145   for (auto d : shape) {
146     output_shape->add_dim()->set_size(d);
147   }
148 }
149 
ConvertTypeToTensorShape(const mlir::Type & type)150 PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
151   if (type.isa<mlir::UnrankedTensorType>()) {
152     // An empty PartialTensorShape indicates an unranked tensor.
153     return PartialTensorShape();
154   }
155 
156   if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
157     TensorShapeProto tensor_shape_proto;
158     ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
159     return PartialTensorShape(tensor_shape_proto);
160   }
161 
162   // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
163   // Empty TensorShape indicates a scalar.
164   return TensorShape();
165 }
166 
167 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
ConvertOpaqueElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)168 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
169                                  TensorProto* output_tensor) {
170   if (attr.isa<OpaqueElementsAttr>()) {
171     auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
172     absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
173     return mangling_util::DemangleTensor(tensor_view, output_tensor);
174   }
175   return InvalidArgument("Unexpected elements attribute type from MLIR.");
176 }
177 
178 // Converts an MLIR elements attribute to a TensorFlow tensor proto
179 // with the double_val field updated.
ConvertDoubleElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)180 Status ConvertDoubleElementsAttr(const ElementsAttr attr,
181                                  TensorProto* output_tensor) {
182   if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
183     if (elts.isSplat()) {
184       output_tensor->add_double_val(elts.getSplatValue<double>());
185     } else {
186       for (auto value : elts.getValues<double>())
187         output_tensor->add_double_val(value);
188     }
189     return Status::OK();
190   }
191   return ConvertOpaqueElementsAttr(attr, output_tensor);
192 }
193 
194 // Converts an MLIR elements attribute to a TensorFlow tensor proto
195 // with the float_val field updated.
ConvertFloatElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)196 Status ConvertFloatElementsAttr(const ElementsAttr attr,
197                                 TensorProto* output_tensor) {
198   if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
199     if (elts.isSplat()) {
200       output_tensor->add_float_val(elts.getSplatValue<float>());
201     } else {
202       for (auto value : elts.getValues<float>())
203         output_tensor->add_float_val(value);
204     }
205     return Status::OK();
206   }
207   return ConvertOpaqueElementsAttr(attr, output_tensor);
208 }
209 
210 // Converts an MLIR elements attribute to a TensorFlow tensor proto
211 // with the half_val field updated.
ConvertHalfElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)212 Status ConvertHalfElementsAttr(const ElementsAttr attr,
213                                TensorProto* output_tensor) {
214   if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
215     if (elts.isSplat()) {
216       output_tensor->add_half_val(
217           (*elts.begin()).bitcastToAPInt().getSExtValue());
218     } else {
219       for (auto value : elts.getFloatValues())
220         output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
221     }
222     return Status::OK();
223   }
224   return ConvertOpaqueElementsAttr(attr, output_tensor);
225 }
226 
227 // Converts an MLIR elements attribute to a TensorFlow tensor proto
228 // with the int_val field updated.
ConvertIntElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)229 Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
230                               TensorProto* output_tensor) {
231   if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
232     if (elts.isSplat()) {
233       output_tensor->add_int_val((*elts.begin()).getSExtValue());
234     } else {
235       for (auto val : elts) output_tensor->add_int_val(val.getSExtValue());
236     }
237     return Status::OK();
238   }
239   return ConvertOpaqueElementsAttr(attr, output_tensor);
240 }
241 
ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)242 Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
243                                    TensorProto* output_tensor) {
244   auto elts = attr.dyn_cast<DenseFPElementsAttr>();
245   if (!elts) {
246     return ConvertOpaqueElementsAttr(attr, output_tensor);
247   }
248 
249   // Bfloat16 is internally represented as `double` in MLIR.
250   if (elts.isSplat()) {
251     double v = elts.getSplatValue<double>();
252     bfloat16 bf16_val = static_cast<bfloat16>(v);
253     output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
254   } else {
255     for (auto v : elts.getValues<double>()) {
256       bfloat16 bf16_val = static_cast<bfloat16>(v);
257       output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
258     }
259   }
260 
261   return Status::OK();
262 }
263 
264 // Converts an MLIR elements attribute to a TensorFlow tensor proto
265 // with the int64_val field updated.
ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)266 Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
267                                 TensorProto* output_tensor) {
268   if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
269     if (elts.isSplat()) {
270       output_tensor->add_int64_val((*elts.begin()).getSExtValue());
271     } else {
272       for (auto val : elts) output_tensor->add_int64_val(val.getSExtValue());
273     }
274     return Status::OK();
275   }
276   return ConvertOpaqueElementsAttr(attr, output_tensor);
277 }
278 
279 // Converts an MLIR elements attribute to a TensorFlow tensor proto
280 // with bool_val field updated.
ConvertBoolElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)281 Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
282                                TensorProto* output_tensor) {
283   if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
284     for (auto val : elts) {
285       output_tensor->add_bool_val(val.getBoolValue());
286     }
287     return Status::OK();
288   }
289   return ConvertOpaqueElementsAttr(attr, output_tensor);
290 }
291 
ConvertToTensorProto(const ElementsAttr attr,TensorProto * output_tensor)292 Status ConvertToTensorProto(const ElementsAttr attr,
293                             TensorProto* output_tensor) {
294   auto type = attr.getType();
295   auto shape = type.getShape();
296   DataType output_dtype;
297   TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
298   output_tensor->set_dtype(output_dtype);
299   ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
300 
301   switch (output_dtype) {
302     case DT_FLOAT:
303       return ConvertFloatElementsAttr(attr, output_tensor);
304     case DT_HALF:
305       // Handles both DenseFPElementsAttr and OpaqueElementsAttr.
306       return ConvertHalfElementsAttr(attr, output_tensor);
307     case DT_DOUBLE:
308       return ConvertDoubleElementsAttr(attr, output_tensor);
309     case DT_QUINT8:
310     case DT_UINT8:
311     case DT_INT8:
312     case DT_QUINT16:
313     case DT_UINT16:
314     case DT_INT16:
315     case DT_INT32:
316       return ConvertIntElementsAttr(attr, output_tensor);
317     case DT_INT64:
318       return ConvertInt64ElementsAttr(attr, output_tensor);
319     case DT_BOOL:
320       return ConvertBoolElementsAttr(attr, output_tensor);
321     case DT_BFLOAT16:
322       return ConvertBfloat16ElementsAttr(attr, output_tensor);
323     default:
324       return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
325                                        output_tensor);
326   }
327 }
328 
ConvertToTensor(const mlir::ElementsAttr attr,Tensor * output_tensor)329 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
330   TensorProto tensor_proto;
331   TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
332   if (!output_tensor->FromProto(tensor_proto)) {
333     return InvalidArgument("Couldn't convert tensor proto to tensor.");
334   }
335   return Status::OK();
336 }
337 
DecodeOpaqueTensor(const mlir::OpaqueElementsAttr input_attr,mlir::Builder builder)338 StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
339     const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
340   // TODO(antiagainst): The following logic, albeit simple, involves copying the
341   // tensor content multiple times, which is bad. Figure out a better way to
342   // achieve the purpose.
343   Tensor tensor;
344   TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
345   return ConvertTensor(tensor, &builder);
346 }
347 
348 }  // namespace tensorflow
349