• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"

#include <cmath>
#include <cstdint>
#include <limits>

#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h"
46 
47 namespace tensorflow {
48 
49 using llvm::ArrayRef;
50 using llvm::SmallVector;
51 using mlir::Builder;
52 using mlir::DenseStringElementsAttr;
53 using mlir::ElementsAttr;
54 using mlir::OpaqueElementsAttr;
55 using mlir::RankedTensorType;
56 using mlir::ShapedType;
57 using mlir::Type;
58 using tensorflow::errors::InvalidArgument;
59 
ConvertToProto(const Tensor & input_tensor,bool use_tensor_content=true)60 static TensorProto ConvertToProto(const Tensor& input_tensor,
61                                   bool use_tensor_content = true) {
62   TensorProto tensor_proto;
63   // Using tensor content (mostly*) reduces serialization overhead during RPC
64   // calls, but is less human reader friendly. People reading protobufs are less
65   // frequent than serialization, so default to using tensor content
66   // representation.
67   // * For scalars and short strings it may be marginally worse and a more
68   //   intelligent decision could be made by caller.
69   if (use_tensor_content)
70     input_tensor.AsProtoTensorContent(&tensor_proto);
71   else
72     input_tensor.AsProtoField(&tensor_proto);
73   return tensor_proto;
74 }
75 
MangleTensor(const Tensor & tensor)76 static std::string MangleTensor(const Tensor& tensor) {
77   return mangling_util::MangleTensor(ConvertToProto(tensor));
78 }
79 
80 // Converts a TensorFlow tensor into an MLIR elements attribute.
81 template <typename T>
ConvertFlatTensor(const Tensor & input_tensor,ShapedType type)82 StatusOr<ElementsAttr> ConvertFlatTensor(const Tensor& input_tensor,
83                                          ShapedType type) {
84   auto arr = input_tensor.flat<T>();
85   return mlir::DenseElementsAttr::get(
86       type, llvm::makeArrayRef(arr.data(), arr.size()));
87 }
88 
ConvertBf16Tensor(const Tensor & input_tensor,RankedTensorType type)89 ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor,
90                                RankedTensorType type) {
91   auto buffer = llvm::makeArrayRef(static_cast<char*>(input_tensor.data()),
92                                    input_tensor.TotalBytes());
93   return mlir::DenseElementsAttr::getFromRawBuffer(
94       type, buffer,
95       /*isSplatBuffer=*/type.getNumElements() == 1);
96 }
97 
ConvertHalfTensor(const Tensor & tensor,RankedTensorType type)98 ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) {
99   auto buffer = llvm::makeArrayRef(static_cast<char*>(tensor.data()),
100                                    tensor.TotalBytes());
101   return mlir::DenseElementsAttr::getFromRawBuffer(
102       type, buffer,
103       /*isSplatBuffer=*/type.getNumElements() == 1);
104 }
105 
ConvertStringTensor(const Tensor & input_tensor,ShapedType type)106 StatusOr<ElementsAttr> ConvertStringTensor(const Tensor& input_tensor,
107                                            ShapedType type) {
108   // Extract to a vector of StringRefs for converting.
109   auto arr = input_tensor.flat<tstring>();
110   std::vector<mlir::StringRef> string_refs;
111   string_refs.reserve(arr.size());
112   for (int i = 0; i < arr.size(); i++) {
113     const auto& val = arr(i);
114     string_refs.push_back({val.data(), val.size()});
115   }
116 
117   return DenseStringElementsAttr::get(type, string_refs);
118 }
119 
ConvertTensor(const Tensor & input_tensor,Builder * builder)120 StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
121                                      Builder* builder) {
122   const auto& input_dtype = input_tensor.dtype();
123   const auto& input_shape = input_tensor.shape();
124   Type elt_type;
125   TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type));
126   SmallVector<int64_t, 4> shape;
127   ConvertToMlirShape(input_shape, &shape);
128   auto type = RankedTensorType::get(shape, elt_type);
129 
130 #define CONVERT_FLAT(DTYPE, CTYPE) \
131   case DTYPE:                      \
132     return ConvertFlatTensor<CTYPE>(input_tensor, type);
133 
134   // TODO(fengliuai): customize the conversions for quantized and string types.
135   switch (input_dtype) {
136     CONVERT_FLAT(DT_BOOL, bool)
137     CONVERT_FLAT(DT_FLOAT, float)
138     CONVERT_FLAT(DT_DOUBLE, double)
139     CONVERT_FLAT(DT_INT8, int8)
140     CONVERT_FLAT(DT_INT16, int16)
141     CONVERT_FLAT(DT_INT32, int32)
142     CONVERT_FLAT(DT_INT64, int64)
143     CONVERT_FLAT(DT_UINT8, uint8)
144     CONVERT_FLAT(DT_UINT16, uint16)
145     CONVERT_FLAT(DT_UINT32, uint32)
146     CONVERT_FLAT(DT_UINT64, uint64)
147     CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
148     CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
149 
150     // BFLOAT16 is a special case that it needs to be cast to double type to
151     // match its storage type.
152     case DT_BFLOAT16:
153       return ConvertBf16Tensor(input_tensor, type);
154     case DT_HALF:
155       return ConvertHalfTensor(input_tensor, type);
156 
157     case DT_STRING:
158       return ConvertStringTensor(input_tensor, type);
159 
160     default:
161       // TODO(shpeisman): restructure code to reuse dialect pointer across
162       // calls.
163       auto* dialect = builder->getContext()->getLoadedDialect("tf");
164       return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor));
165   }
166 
167 #undef CONVERT_FLAT
168 }
169 
170 // Returns the number of elements present in this TensorProto, or -1 if that
171 // could not be determined. This might be less than the shape of the proto might
172 // indicate, if we're storing a splat tensor.
NumberOfMaterializedElements(const TensorProto & tensor)173 int NumberOfMaterializedElements(const TensorProto& tensor) {
174   if (!tensor.tensor_content().empty()) return -1;
175     // We don't know which element type this protocol buffer is storing, and the
176     // metaprogramming facilities for TensorProto are too limited to check their
177     // number without knowing this, so we need to manually dispatch to each
178     // possible member of TensorProto, depening on its dtype.
179 #define MATCH(DTYPE, FIELD) \
180   case DTYPE:               \
181     return tensor.FIELD##_val().size()
182 
183   switch (tensor.dtype()) {
184     MATCH(DT_FLOAT, float);
185     MATCH(DT_DOUBLE, double);
186     MATCH(DT_INT8, int);
187     MATCH(DT_UINT8, int);
188     MATCH(DT_INT16, int);
189     MATCH(DT_UINT16, int);
190     MATCH(DT_INT32, int);
191     MATCH(DT_UINT32, uint32);
192     MATCH(DT_INT64, int64);
193     MATCH(DT_UINT64, uint64);
194     MATCH(DT_BOOL, bool);
195     MATCH(DT_HALF, half);
196     MATCH(DT_BFLOAT16, half);
197     MATCH(DT_STRING, string);
198 
199     // TODO(b/188995810): DenseElementsAttr::get doesn't support complex
200     // Attributes being passed, so we bail out for now. This should just be
201     //   MATCH(DT_COMPLEX64, scomplex) / 2;
202     //   MATCH(DT_COMPLEX128, dcomplex) / 2;
203     // when DenseElementsAttr is updated.
204     case DT_COMPLEX64:
205     case DT_COMPLEX128:
206     default:
207       return -1;
208   }
209 }
210 
ConvertTensorProto(const TensorProto & input_tensor,Builder * builder)211 StatusOr<ElementsAttr> ConvertTensorProto(const TensorProto& input_tensor,
212                                           Builder* builder) {
213   // If there is only one actual element in the proto, but its shape would
214   // indicate there are more values, then this is representing a splat tensor.
215   // We can create an MLIR Attribute more efficiently in this case.
216   TensorShape input_tensor_shape(input_tensor.tensor_shape());
217   if (NumberOfMaterializedElements(input_tensor) == 1 &&
218       input_tensor_shape.num_elements() > 1) {
219     // We first convert this TensorProto to one of shape [1]. We then create an
220     // Attribute for that proto, and finally splat the Attribute.
221 
222     TensorProto tensor_copy = input_tensor;
223     auto* shape = tensor_copy.mutable_tensor_shape();
224     shape->clear_dim();
225     shape->add_dim()->set_size(1);
226 
227     TF_ASSIGN_OR_RETURN(ElementsAttr single_attr,
228                         ConvertTensorProto(tensor_copy, builder));
229 
230     std::vector<int64_t> original_dimensions;
231     for (auto dim : input_tensor_shape) original_dimensions.push_back(dim.size);
232     return mlir::SplatElementsAttr::get(
233         single_attr.getType().clone(original_dimensions),
234         single_attr.getValue({0}));
235   }
236 
237   Tensor t;
238   if (!t.FromProto(input_tensor))
239     return InvalidArgument("Failed to parse input_tensor.");
240   return ConvertTensor(t, builder);
241 }
242 
ConvertToTensorShapeProto(ArrayRef<int64_t> shape,TensorShapeProto * output_shape)243 void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
244                                TensorShapeProto* output_shape) {
245   for (auto d : shape) {
246     output_shape->add_dim()->set_size(d);
247   }
248 }
249 
ConvertTypeToTensorShape(const mlir::Type & type)250 PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
251   if (type.isa<mlir::UnrankedTensorType>()) {
252     // An empty PartialTensorShape indicates an unranked tensor.
253     return PartialTensorShape();
254   }
255 
256   if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
257     TensorShapeProto tensor_shape_proto;
258     ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
259     return PartialTensorShape(tensor_shape_proto);
260   }
261 
262   // If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
263   // Empty TensorShape indicates a scalar.
264   return TensorShape();
265 }
266 
// Maps an MLIR type to a TF shape attribute:
//   * unranked tensor -> attribute built from llvm::None,
//   * ranked tensor   -> attribute built from the tensor type's dimensions,
//   * anything else   -> attribute built from an empty dimension list, i.e.
//                        treated as a scalar.
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
  auto* context = type.getContext();

  if (type.isa<mlir::UnrankedTensorType>())
    return mlir::TF::ShapeAttr::get(context, llvm::None);

  auto ranked_type = type.dyn_cast<mlir::RankedTensorType>();
  if (!ranked_type) {
    // Neither ranked nor unranked tensor: it must be a scalar, which is
    // represented by an empty shape.
    return mlir::TF::ShapeAttr::get(context, ArrayRef<int64_t>());
  }

  return mlir::TF::ShapeAttr::get(context, ranked_type.getShape());
}
280 
281 // Converts the tensor shape proto into an MLIR shape attribute.
ConvertTensorShapeProto(const TensorShapeProto & shape,mlir::MLIRContext * context)282 StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
283                                                   mlir::MLIRContext* context) {
284   if (shape.unknown_rank())
285     return mlir::TF::ShapeAttr::get(context, llvm::None);
286 
287   llvm::SmallVector<int64_t, 4> dims;
288   dims.reserve(shape.dim().size());
289   for (const auto& dim : shape.dim()) {
290     dims.push_back(dim.size());
291   }
292   return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
293 }
294 
295 // Converts an MLIR dense string elements attribute to a TensorFlow tensor
296 // proto.
ConvertStringElementsAttr(const DenseStringElementsAttr attr,protobuf::RepeatedPtrField<std::string> * output)297 void ConvertStringElementsAttr(
298     const DenseStringElementsAttr attr,
299     protobuf::RepeatedPtrField<std::string>* output) {
300   for (const auto& val : attr.getRawStringData())
301     output->Add({val.data(), val.size()});
302 }
303 
304 template <typename T>
ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)305 void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
306                                 protobuf::RepeatedField<T>* output) {
307   for (const auto& val : attr.getValues<std::complex<T>>()) {
308     output->Add(val.real());
309     output->Add(val.imag());
310   }
311 }
312 
313 // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
ConvertOpaqueElementsAttr(const ElementsAttr attr,TensorProto * output_tensor)314 Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
315                                  TensorProto* output_tensor) {
316   if (attr.isa<OpaqueElementsAttr>()) {
317     auto mangled_tensor = attr.cast<OpaqueElementsAttr>().getValue();
318     absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size());
319     return mangling_util::DemangleTensor(tensor_view, output_tensor);
320   }
321   return InvalidArgument("Unexpected elements attribute type from MLIR.");
322 }
323 
324 template <typename T>
ConvertElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output)325 void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
326                          protobuf::RepeatedField<T>* output) {
327   if (attr.isSplat()) {
328     if (attr.getSplatValue<T>() != T(0)) output->Add(attr.getSplatValue<T>());
329   } else {
330     output->Reserve(attr.getNumElements());
331     for (auto value : attr.getValues<T>()) output->AddAlreadyReserved(value);
332   }
333 }
334 
335 // Converts an MLIR elements attribute and adds it to specified repeated field.
336 template <typename T, typename Cord>
ConvertFloatElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output,Cord * tensor_content)337 void ConvertFloatElementsAttr(const mlir::DenseElementsAttr attr,
338                               protobuf::RepeatedField<T>* output,
339                               Cord* tensor_content) {
340   if (attr.isSplat()) {
341     if (attr.getSplatValue<T>() != T(0)) output->Add(attr.getSplatValue<T>());
342   } else {
343     port::CopyFromArray(tensor_content, attr.getRawData().data(),
344                         attr.getRawData().size());
345   }
346 }
347 
348 // Converts an MLIR elements attribute containing half values and adds it to
349 // specified repeated field.
ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)350 void ConvertHalfElementsAttr(const mlir::DenseElementsAttr attr,
351                              protobuf::RepeatedField<int>* output) {
352   if (attr.isSplat()) {
353     if (attr.getSplatValue<Eigen::half>() != Eigen::half(0))
354       output->Add(
355           Eigen::numext::bit_cast<uint16_t>(attr.getSplatValue<Eigen::half>()));
356   } else {
357     output->Reserve(attr.getNumElements());
358     for (const Eigen::half value : attr.getValues<Eigen::half>())
359       output->AddAlreadyReserved(Eigen::numext::bit_cast<uint16_t>(value));
360   }
361 }
362 
363 // Converts an MLIR elements attribute containing signed int values and adds it
364 // to specified repeated field.
365 template <typename T, typename U = T, typename Cord>
ConvertIntElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output,Cord * tensor_content)366 void ConvertIntElementsAttr(const mlir::DenseElementsAttr attr,
367                             protobuf::RepeatedField<T>* output,
368                             Cord* tensor_content) {
369   if (attr.isSplat()) {
370     if (attr.getSplatValue<U>() != U(0)) output->Add(attr.getSplatValue<U>());
371   } else {
372     port::CopyFromArray(tensor_content, attr.getRawData().data(),
373                         attr.getRawData().size());
374   }
375 }
376 
377 // Converts an MLIR elements attribute containing unsigned int values and adds
378 // it to specified repeated field.
379 template <typename T, typename U = T, typename Cord>
ConvertUIntElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<T> * output,Cord * tensor_content)380 void ConvertUIntElementsAttr(const mlir::DenseElementsAttr attr,
381                              protobuf::RepeatedField<T>* output,
382                              Cord* tensor_content) {
383   if (attr.isSplat()) {
384     if (attr.getSplatValue<U>() != U(0)) output->Add(attr.getSplatValue<U>());
385   } else {
386     port::CopyFromArray(tensor_content, attr.getRawData().data(),
387                         attr.getRawData().size());
388   }
389 }
390 
ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,protobuf::RepeatedField<int> * output)391 void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr,
392                                  protobuf::RepeatedField<int>* output) {
393   if (attr.isSplat()) {
394     if (attr.getSplatValue<bfloat16>() != bfloat16(0))
395       output->Add(
396           Eigen::numext::bit_cast<uint16_t>(attr.getSplatValue<bfloat16>()));
397   } else {
398     output->Reserve(attr.getNumElements());
399     for (const bfloat16 value : attr.getValues<bfloat16>())
400       output->AddAlreadyReserved(Eigen::numext::bit_cast<uint16_t>(value));
401   }
402 }
403 
ConvertToTensorProto(const ElementsAttr attr,TensorProto * output)404 Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
405   auto type = attr.getType();
406   auto shape = type.getShape();
407   DataType output_dtype;
408   TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
409   output->set_dtype(output_dtype);
410   ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
411 
412   if (attr.isa<OpaqueElementsAttr>())
413     return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
414 
415   auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
416   if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
417 
418   switch (output_dtype) {
419     case DT_BOOL:
420       ConvertElementsAttr(dense_attr, output->mutable_bool_val());
421       break;
422     case DT_BFLOAT16:
423       ConvertBfloat16ElementsAttr(dense_attr, output->mutable_half_val());
424       break;
425     case DT_COMPLEX64:
426       ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
427       break;
428     case DT_COMPLEX128:
429       ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
430       break;
431     case DT_DOUBLE:
432       ConvertFloatElementsAttr(dense_attr, output->mutable_double_val(),
433                                output->mutable_tensor_content());
434       break;
435     case DT_HALF:
436       ConvertHalfElementsAttr(dense_attr, output->mutable_half_val());
437       break;
438     case DT_FLOAT:
439       ConvertFloatElementsAttr(dense_attr, output->mutable_float_val(),
440                                output->mutable_tensor_content());
441       break;
442     case DT_QUINT8:
443     case DT_INT8:
444       ConvertUIntElementsAttr<int, int8_t>(dense_attr,
445                                            output->mutable_int_val(),
446                                            output->mutable_tensor_content());
447       break;
448     case DT_QUINT16:
449     case DT_INT16:
450       ConvertIntElementsAttr<int, int16_t>(dense_attr,
451                                            output->mutable_int_val(),
452                                            output->mutable_tensor_content());
453       break;
454     case DT_INT32:
455       ConvertIntElementsAttr(dense_attr, output->mutable_int_val(),
456                              output->mutable_tensor_content());
457       break;
458     case DT_INT64:
459       ConvertIntElementsAttr(dense_attr, output->mutable_int64_val(),
460                              output->mutable_tensor_content());
461       break;
462     case DT_STRING:
463       ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
464                                 output->mutable_string_val());
465       break;
466     case DT_UINT8:
467       ConvertUIntElementsAttr<int, uint8_t>(dense_attr,
468                                             output->mutable_int_val(),
469                                             output->mutable_tensor_content());
470       break;
471     case DT_UINT16:
472       ConvertUIntElementsAttr<int, uint16_t>(dense_attr,
473                                              output->mutable_int_val(),
474                                              output->mutable_tensor_content());
475       break;
476     case DT_UINT32:
477       ConvertUIntElementsAttr(dense_attr, output->mutable_uint32_val(),
478                               output->mutable_tensor_content());
479       break;
480     case DT_UINT64:
481       ConvertUIntElementsAttr(dense_attr, output->mutable_uint64_val(),
482                               output->mutable_tensor_content());
483       break;
484     default:
485       return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
486                                                 DataTypeString(output_dtype)));
487   }
488   return Status::OK();
489 }
490 
ConvertToTensor(const mlir::ElementsAttr attr,Tensor * output_tensor)491 Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
492   TensorProto tensor_proto;
493   TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor_proto));
494   if (!output_tensor->FromProto(tensor_proto)) {
495     return InvalidArgument("Couldn't convert tensor proto to tensor.");
496   }
497   return Status::OK();
498 }
499 
DecodeOpaqueTensor(const mlir::OpaqueElementsAttr input_attr,mlir::Builder builder)500 StatusOr<mlir::ElementsAttr> DecodeOpaqueTensor(
501     const mlir::OpaqueElementsAttr input_attr, mlir::Builder builder) {
502   // TODO(antiagainst): The following logic, albeit simple, involves copying the
503   // tensor content multiple times, which is bad. Figure out a better way to
504   // achieve the purpose.
505   Tensor tensor;
506   TF_RETURN_IF_ERROR(ConvertToTensor(input_attr, &tensor));
507   return ConvertTensor(tensor, &builder);
508 }
509 
510 }  // namespace tensorflow
511