1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
17
18 #include <limits>
19
20 #include "absl/base/casts.h"
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "mlir/IR/Attributes.h" // TF:llvm-project
28 #include "mlir/IR/Builders.h" // TF:llvm-project
29 #include "mlir/IR/StandardTypes.h" // TF:llvm-project
30 #include "mlir/IR/Types.h" // TF:llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor.pb.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/types.pb.h"
38 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/protobuf.h"
41 #include "tensorflow/stream_executor/lib/statusor.h"
42
43 namespace tensorflow {
44
45 using llvm::ArrayRef;
46 using llvm::SmallVector;
47 using mlir::Builder;
48 using mlir::DenseFPElementsAttr;
49 using mlir::DenseIntElementsAttr;
50 using mlir::ElementsAttr;
51 using mlir::OpaqueElementsAttr;
52 using mlir::RankedTensorType;
53 using mlir::ShapedType;
54 using mlir::Type;
55 using tensorflow::errors::InvalidArgument;
56
ConvertToProto(const Tensor & input_tensor,bool use_tensor_content=true)57 static TensorProto ConvertToProto(const Tensor& input_tensor,
58 bool use_tensor_content = true) {
59 TensorProto tensor_proto;
60 // Using tensor content (mostly*) reduces serialization overhead during RPC
61 // calls, but is less human reader friendly. People reading protobufs are less
62 // frequent than serialization, so default to using tensor content
63 // representation.
64 // * For scalars and short strings it may be marginally worse and a more
65 // intelligent decision could be made by caller.
66 if (use_tensor_content)
67 input_tensor.AsProtoTensorContent(&tensor_proto);
68 else
69 input_tensor.AsProtoField(&tensor_proto);
70 return tensor_proto;
71 }
72
MangleTensor(const Tensor & tensor)73 static std::string MangleTensor(const Tensor& tensor) {
74 return mangling_util::MangleTensor(ConvertToProto(tensor));
75 }
76
77 // Converts a TensorFlow tensor into an MLIR elements attribute.
78 template <typename T>
ConvertFlatTensor(const Tensor & input_tensor,ShapedType type)79 StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
80 ShapedType type) {
81 auto arr = input_tensor.flat<T>();
82 return mlir::DenseElementsAttr::get(
83 type, llvm::makeArrayRef(arr.data(), arr.size()));
84 }
85
ConvertBF16Tensor(const Tensor & input_tensor,ShapedType type)86 StatusOr<ElementsAttr> ConvertBF16Tensor(const Tensor& input_tensor,
87 ShapedType type) {
88 auto flat = input_tensor.flat<bfloat16>();
89
90 llvm::SmallVector<double, 4> flat_double;
91 flat_double.reserve(flat.size());
92 for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) {
93 flat_double.push_back(static_cast<double>(v));
94 }
95 return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(flat_double));
96 }
97
// Converts a TensorFlow Tensor into an MLIR ElementsAttr: dtypes with a
// direct dense representation are converted via ConvertFlatTensor /
// ConvertBF16Tensor; all other dtypes fall back to an opaque attribute
// holding the mangled serialized proto.
StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
                                     Builder* builder) {
  const auto& input_dtype = input_tensor.dtype();
  const auto& input_shape = input_tensor.shape();
  Type elt_type;
  TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
  SmallVector<int64_t, 4> shape;
  ConvertToMlirShape(input_shape, &shape);
  auto type = RankedTensorType::get(shape, elt_type);

// Expands to a switch case that converts the tensor's flat buffer of CTYPE
// elements into a dense attribute.
#define CONVERT_FLAT(DTYPE, CTYPE) \
  case DTYPE:                      \
    return ConvertFlatTensor<CTYPE>(input_tensor, type);

  // TODO(fengliuai): customize the conversions for more types.
  switch (input_dtype) {
    CONVERT_FLAT(DT_BOOL, bool)
    CONVERT_FLAT(DT_FLOAT, float)
    CONVERT_FLAT(DT_DOUBLE, double)
    CONVERT_FLAT(DT_INT32, int32)
    CONVERT_FLAT(DT_INT64, int64)

    // BFLOAT16 is a special case that it needs to be cast to double type to
    // match its storage type.
    case DT_BFLOAT16:
      return ConvertBF16Tensor(input_tensor, type);

    default:
      // TODO(shpeisman): restructure code to reuse dialect pointer across
      // calls.
      auto* dialect = builder->getContext()->getRegisteredDialect("tf");
      return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
  }

#undef CONVERT_FLAT
}
134
ConvertTensorProto(const TensorProto & input_tensor,Builder * builder)135 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
136 Builder* builder) {
137 Tensor t;
138 if (!t.FromProto(input_tensor))
139 return InvalidArgument("Failed to parse input_tensor.");
140 return ConvertTensor(t, builder);
141 }
142
ConvertToTensorShapeProto(ArrayRef<int64_t> shape,TensorShapeProto * output_shape)143 void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
144 TensorShapeProto* output_shape) {
145 for (auto d : shape) {
146 output_shape->add_dim()->set_size(d);
147 }
148 }
149
ConvertTypeToTensorShape(const mlir::Type & type)150 PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
151 if (type.isa<mlir::UnrankedTensorType>()) {
152 // An empty PartialTensorShape indicates an unranked tensor.
153 return PartialTensorShape();
154 }
155
156 if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
157 TensorShapeProto tensor_shape_proto;
158 ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
159 return PartialTensorShape(tensor_shape_proto);
160 }
161
162 // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
163 // Empty TensorShape indicates a scalar.
164 return TensorShape();
165 }
166
167 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
ConvertOpaqueElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)168 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
169 TensorProto* output_tensor) {
170 if (attr.isa<OpaqueElementsAttr>()) {
171 auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
172 absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
173 return mangling_util::DemangleTensor(tensor_view, output_tensor);
174 }
175 return InvalidArgument("Unexpected elements attribute type from MLIR.");
176 }
177
178 // Converts an MLIR elements attribute to a TensorFlow tensor proto
179 // with the double_val field updated.
ConvertDoubleElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)180 Status ConvertDoubleElementsAttr(const ElementsAttr attr,
181 TensorProto* output_tensor) {
182 if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
183 if (elts.isSplat()) {
184 output_tensor->add_double_val(elts.getSplatValue<double>());
185 } else {
186 for (auto value : elts.getValues<double>())
187 output_tensor->add_double_val(value);
188 }
189 return Status::OK();
190 }
191 return ConvertOpaqueElementsAttr(attr, output_tensor);
192 }
193
194 // Converts an MLIR elements attribute to a TensorFlow tensor proto
195 // with the float_val field updated.
ConvertFloatElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)196 Status ConvertFloatElementsAttr(const ElementsAttr attr,
197 TensorProto* output_tensor) {
198 if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
199 if (elts.isSplat()) {
200 output_tensor->add_float_val(elts.getSplatValue<float>());
201 } else {
202 for (auto value : elts.getValues<float>())
203 output_tensor->add_float_val(value);
204 }
205 return Status::OK();
206 }
207 return ConvertOpaqueElementsAttr(attr, output_tensor);
208 }
209
210 // Converts an MLIR elements attribute to a TensorFlow tensor proto
211 // with the half_val field updated.
ConvertHalfElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)212 Status ConvertHalfElementsAttr(const ElementsAttr attr,
213 TensorProto* output_tensor) {
214 if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
215 if (elts.isSplat()) {
216 output_tensor->add_half_val(
217 (*elts.begin()).bitcastToAPInt().getSExtValue());
218 } else {
219 for (auto value : elts.getFloatValues())
220 output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
221 }
222 return Status::OK();
223 }
224 return ConvertOpaqueElementsAttr(attr, output_tensor);
225 }
226
227 // Converts an MLIR elements attribute to a TensorFlow tensor proto
228 // with the int_val field updated.
ConvertIntElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)229 Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
230 TensorProto* output_tensor) {
231 if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
232 if (elts.isSplat()) {
233 output_tensor->add_int_val((*elts.begin()).getSExtValue());
234 } else {
235 for (auto val : elts) output_tensor->add_int_val(val.getSExtValue());
236 }
237 return Status::OK();
238 }
239 return ConvertOpaqueElementsAttr(attr, output_tensor);
240 }
241
ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)242 Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
243 TensorProto* output_tensor) {
244 auto elts = attr.dyn_cast<DenseFPElementsAttr>();
245 if (!elts) {
246 return ConvertOpaqueElementsAttr(attr, output_tensor);
247 }
248
249 // Bfloat16 is internally represented as `double` in MLIR.
250 if (elts.isSplat()) {
251 double v = elts.getSplatValue<double>();
252 bfloat16 bf16_val = static_cast<bfloat16>(v);
253 output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
254 } else {
255 for (auto v : elts.getValues<double>()) {
256 bfloat16 bf16_val = static_cast<bfloat16>(v);
257 output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
258 }
259 }
260
261 return Status::OK();
262 }
263
264 // Converts an MLIR elements attribute to a TensorFlow tensor proto
265 // with the int64_val field updated.
ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)266 Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
267 TensorProto* output_tensor) {
268 if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
269 if (elts.isSplat()) {
270 output_tensor->add_int64_val((*elts.begin()).getSExtValue());
271 } else {
272 for (auto val : elts) output_tensor->add_int64_val(val.getSExtValue());
273 }
274 return Status::OK();
275 }
276 return ConvertOpaqueElementsAttr(attr, output_tensor);
277 }
278
279 // Converts an MLIR elements attribute to a TensorFlow tensor proto
280 // with bool_val field updated.
ConvertBoolElementsAttr(const mlir::ElementsAttr attr,TensorProto * output_tensor)281 Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
282 TensorProto* output_tensor) {
283 if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
284 for (auto val : elts) {
285 output_tensor->add_bool_val(val.getBoolValue());
286 }
287 return Status::OK();
288 }
289 return ConvertOpaqueElementsAttr(attr, output_tensor);
290 }
291
ConvertToTensorProto(const ElementsAttr attr,TensorProto * output_tensor)292 Status ConvertToTensorProto(const ElementsAttr attr,
293 TensorProto* output_tensor) {
294 auto type = attr.getType();
295 auto shape = type.getShape();
296 DataType output_dtype;
297 TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
298 output_tensor->set_dtype(output_dtype);
299 ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
300
301 switch (output_dtype) {
302 case DT_FLOAT:
303 return ConvertFloatElementsAttr(attr, output_tensor);
304 case DT_HALF:
305 // Handles both DenseFPElementsAttr and OpaqueElementsAttr.
306 return ConvertHalfElementsAttr(attr, output_tensor);
307 case DT_DOUBLE:
308 return ConvertDoubleElementsAttr(attr, output_tensor);
309 case DT_QUINT8:
310 case DT_UINT8:
311 case DT_INT8:
312 case DT_QUINT16:
313 case DT_UINT16:
314 case DT_INT16:
315 case DT_INT32:
316 return ConvertIntElementsAttr(attr, output_tensor);
317 case DT_INT64:
318 return ConvertInt64ElementsAttr(attr, output_tensor);
319 case DT_BOOL:
320 return ConvertBoolElementsAttr(attr, output_tensor);
321 case DT_BFLOAT16:
322 return ConvertBfloat16ElementsAttr(attr, output_tensor);
323 default:
324 return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
325 output_tensor);
326 }
327 }
328
ConvertToTensor(const mlir::ElementsAttr attr,Tensor * output_tensor)329 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
330 TensorProto tensor_proto;
331 TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
332 if (!output_tensor->FromProto(tensor_proto)) {
333 return InvalidArgument("Couldn't convert tensor proto to tensor.");
334 }
335 return Status::OK();
336 }
337
DecodeOpaqueTensor(const mlir::OpaqueElementsAttr input_attr,mlir::Builder builder)338 StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
339 const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
340 // TODO(antiagainst): The following logic, albeit simple, involves copying the
341 // tensor content multiple times, which is bad. Figure out a better way to
342 // achieve the purpose.
343 Tensor tensor;
344 TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
345 return ConvertTensor(tensor, &builder);
346 }
347
348 } // namespace tensorflow
349