/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/util/attr_util.h"

#include <cstdlib>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/logging.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {
namespace {

using ::tensorflow::protobuf::RepeatedFieldBackInserter;
using ::tfrt::AggregateAttr;
using ::tfrt::BEFAttributeType;
using ::tfrt::DenseAttr;
using ::tfrt::DenseHostTensor;
using ::tfrt::HostContext;
using ::tfrt::OpAttrsRawEntry;
using ::tfrt::OpAttrsRef;
using ::tfrt::OpAttrType;
using ::tfrt::SmallVector;
using ::tfrt::string_view;

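// Deserializes a BEF DenseAttr into a DenseHostTensor and converts it into a
// tensorflow::Tensor.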
llvm::Expected<tensorflow::Tensor> DecodeDenseAttrToTfTensor(
    const DenseAttr& dense_attr, HostContext* host) {
  llvm::Expected<DenseHostTensor> dht =
      tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
  if (!dht) {
    return tfrt::MakeStringError(
        "Cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface: ",
        dht.takeError());
  }

  return tfrt::TFRTTensorToTFTensor(*dht, host);
}

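// Fills `attr_tmp` from an array-typed attribute entry: strings, function
// names, and arrays of i64, f32, bool, and dtype values.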
llvm::Error FillAttrValueMapUsingArray(const OpAttrsRawEntry& entry,
                                       AttrValue& attr_tmp,
                                       const OpAttrsRef& attrs) {
  attr_tmp.mutable_list()->Clear();
  if (entry.element_count == 0) {
    if (entry.type == OpAttrType::CHAR) {
      // Empty string.
      attr_tmp.set_s("");
    }
    // Empty array of other types.
    return llvm::Error::success();
  }
  switch (entry.type) {
    case OpAttrType::CHAR: {
      string_view attr_value = attrs.GetStringAsserting(entry.name);
      attr_tmp.set_s(attr_value.data(), attr_value.size());
      return llvm::Error::success();
    }

    case OpAttrType::FUNC: {
      string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
      attr_tmp.mutable_func()->set_name(attr_value.data(), attr_value.size());
      return llvm::Error::success();
    }
    case OpAttrType::I64: {
      llvm::ArrayRef<int64_t> int_array =
          attrs.GetArrayAsserting<int64_t>(entry.name);
      auto* mutable_i = attr_tmp.mutable_list()->mutable_i();
      std::copy(int_array.begin(), int_array.end(),
                RepeatedFieldBackInserter(mutable_i));
      return llvm::Error::success();
    }
    case OpAttrType::F32: {
      llvm::ArrayRef<float> float_array =
          attrs.GetArrayAsserting<float>(entry.name);
      auto* mutable_f = attr_tmp.mutable_list()->mutable_f();
      std::copy(float_array.begin(), float_array.end(),
                RepeatedFieldBackInserter(mutable_f));
      return llvm::Error::success();
    }
    case OpAttrType::BOOL: {
      llvm::ArrayRef<bool> bool_array =
          attrs.GetArrayAsserting<bool>(entry.name);
      auto mutable_b = attr_tmp.mutable_list()->mutable_b();
      std::copy(bool_array.begin(), bool_array.end(),
                RepeatedFieldBackInserter(mutable_b));
      return llvm::Error::success();
    }
    case OpAttrType::DTYPE: {
      const auto& op_attr = attrs.GetRawAsserting(entry.name);
      assert(op_attr.IsArray());

      // DTypes in BEF attributes are tfrt::DType enums. So we need
      // to convert them to tensorflow data types first.
      auto bef_dtypes =
          llvm::makeArrayRef(static_cast<const tfrt::DType*>(op_attr.GetData()),
                             op_attr.element_count);

      SmallVector<tensorflow::DataType, 4> tf_dtypes;
      tf_dtypes.reserve(bef_dtypes.size());
      for (auto bef_dtype : bef_dtypes) {
        tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
      }
      auto* mutable_type = attr_tmp.mutable_list()->mutable_type();
      std::copy(tf_dtypes.begin(), tf_dtypes.end(),
                RepeatedFieldBackInserter(mutable_type));
      return llvm::Error::success();
    }
    default:
      return tfrt::MakeStringError("unsupported array attribute type");
  }
}

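// Fills `attr_tmp` from an aggregate attribute entry, which encodes
// list(string), list(func), and list(TensorShape) attributes.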
llvm::Error FillAttrValueMapUsingAggregate(const OpAttrsRawEntry& entry,
                                           AttrValue& attr_tmp,
                                           const OpAttrsRef& attrs) {
  AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
  int num_values = list_attr.GetNumElements();
  if (num_values == 0) {
    // Create an empty list.
    attr_tmp.mutable_list();
    return llvm::Error::success();
  }
  // It is guaranteed that items in one list attribute have the same
  // type, though their sizes can be different. In particular,
  // list(TensorShape) and list(Tensor) attribute types have to be
  // encoded as AggregateAttr.
  auto attr_base = list_attr.GetAttribute(0);
  auto* mutable_list = attr_tmp.mutable_list();
  mutable_list->Clear();
  if (IsDataTypeAttribute(attr_base.type()) &&
      GetDataType(attr_base.type()) == tfrt::DType::String) {
    // Handle list(string).
    auto* mutable_s = mutable_list->mutable_s();
    mutable_s->Reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      auto string_attr = list_attr.GetAttributeOfType<tfrt::StringAttr>(i);
      mutable_list->add_s(string_attr.GetValue().data(),
                          string_attr.GetValue().size());
    }
  } else if (attr_base.type() == BEFAttributeType::kFunc) {
    // Handle list(Function).
    auto* mutable_f = mutable_list->mutable_func();
    mutable_f->Reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      auto func_attr = list_attr.GetAttributeOfType<tfrt::FuncAttr>(i);
      auto mutable_func = mutable_list->add_func();
      mutable_func->set_name(func_attr.GetFunctionName().str());
    }
  } else if (attr_base.type() == BEFAttributeType::kShape) {
    // Handle list(TensorShape).
    auto* mutable_list = attr_tmp.mutable_list();
    auto* mutable_shape = mutable_list->mutable_shape();
    mutable_shape->Reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      auto shape_attr = list_attr.GetAttributeOfType<tfrt::ShapeAttr>(i);
      auto* added_shape = mutable_list->add_shape();
      if (shape_attr.HasRank()) {
        int rank = shape_attr.GetRank();
        auto shape = shape_attr.GetShape();
        added_shape->mutable_dim()->Reserve(rank);
        for (int d = 0; d < rank; ++d) {
          added_shape->add_dim()->set_size(shape[d]);
        }
      } else {
        added_shape->set_unknown_rank(true);
      }
    }
  } else {
    return tfrt::MakeStringError("unsupported list attribute type");
  }
  return llvm::Error::success();
}

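// Fills `attr_tmp` from a scalar attribute entry: i64, f32, bool, dtype,
// shape, dense tensor, or a nested aggregate.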
llvm::Error FillAttrValueMapUsingScalar(const OpAttrsRawEntry& entry,
                                        AttrValue& attr_tmp, HostContext* host,
                                        const OpAttrsRef& attrs) {
  switch (entry.type) {
    case OpAttrType::I64: {
      int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
      attr_tmp.set_i(attr_value);
      return llvm::Error::success();
    }
    case OpAttrType::F32: {
      float attr_value = attrs.GetAsserting<float>(entry.name);
      attr_tmp.set_f(attr_value);
      return llvm::Error::success();
    }
    case OpAttrType::BOOL: {
      bool attr_value = attrs.GetAsserting<bool>(entry.name);
      attr_tmp.set_b(attr_value);
      return llvm::Error::success();
    }
    case OpAttrType::DTYPE: {
      OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
      DataType tf_dtype = ConvertToTfDataType(op_attr_type);
      attr_tmp.set_type(tf_dtype);
      return llvm::Error::success();
    }
    case OpAttrType::SHAPE: {
      auto shape_attr = attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
      auto* mutable_shape = attr_tmp.mutable_shape();
      if (shape_attr.HasRank()) {
        int rank = shape_attr.GetRank();
        auto shape = shape_attr.GetShape();
        mutable_shape->mutable_dim()->Reserve(rank);
        for (int d = 0; d < rank; ++d) {
          mutable_shape->add_dim()->set_size(shape[d]);
        }
      } else {
        mutable_shape->set_unknown_rank(true);
      }
      return llvm::Error::success();
    }
    case OpAttrType::DENSE: {
      auto dense_attr = attrs.GetAsserting<tfrt::DenseAttr>(entry.name);
      llvm::Expected<tensorflow::Tensor> tf_tensor =
          DecodeDenseAttrToTfTensor(dense_attr, host);
      if (!tf_tensor) return tf_tensor.takeError();
      auto* mutable_tensor = attr_tmp.mutable_tensor();
      if (tf_tensor->NumElements() > 1) {
        tf_tensor->AsProtoTensorContent(mutable_tensor);
      } else {
        tf_tensor->AsProtoField(mutable_tensor);
      }
      return llvm::Error::success();
    }
    case OpAttrType::AGGREGATE: {
      return FillAttrValueMapUsingAggregate(entry, attr_tmp, attrs);
    }
    default:
      LOG(ERROR) << "failure case";
      return tfrt::MakeStringError("unsupported scalar attribute type");
  }
}

}  // namespace

Status ParseTfDataType(absl::string_view dtype, DataType* data_type) {
  if (dtype == "DT_INT8") {
    *data_type = DataType::DT_INT8;
    return Status::OK();
  } else if (dtype == "DT_INT32") {
    *data_type = DataType::DT_INT32;
    return Status::OK();
  } else if (dtype == "DT_INT64") {
    *data_type = DataType::DT_INT64;
    return Status::OK();
  } else if (dtype == "DT_HALF") {
    *data_type = DataType::DT_HALF;
    return Status::OK();
  } else if (dtype == "DT_FLOAT") {
    *data_type = DataType::DT_FLOAT;
    return Status::OK();
  } else if (dtype == "DT_DOUBLE") {
    *data_type = DataType::DT_DOUBLE;
    return Status::OK();
  } else {
    return errors::InvalidArgument("Unsupported dtype, ", std::string(dtype),
                                   " in ParseTfDataType.");
  }
}

DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type) {
  switch (op_attr_type) {
#define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
  case tfrt::OpAttrType::TFRT_ENUM:      \
    return DataType::DT_ENUM;
#include "tensorflow/core/runtime_fallback/util/attr_type.def"  // NOLINT
    default:
      TFRT_DLOG(ERROR) << "unsupported dtype "
                       << static_cast<int>(op_attr_type)
                       << " in TFRT fallback kernel.";
      abort();
  }
}

tfrt::OpAttrType ConvertFromTfDataType(DataType data_type) {
  switch (data_type) {
#define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
  case DataType::DT_ENUM:                \
    return tfrt::OpAttrType::TFRT_ENUM;
#include "tensorflow/core/runtime_fallback/util/attr_type.def"  // NOLINT
    default:
      TFRT_DLOG(ERROR) << "unsupported dtype " << static_cast<int>(data_type)
                       << " in TFRT fallback kernel.";
      abort();
  }
}

DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type) {
  switch (attr_type) {
    case tfrt::DType::I1:
      return DataType::DT_BOOL;
    case tfrt::DType::I8:
      return DataType::DT_INT8;
    case tfrt::DType::I16:
      return DataType::DT_INT16;
    case tfrt::DType::I32:
      return DataType::DT_INT32;
    case tfrt::DType::I64:
      return DataType::DT_INT64;
    case tfrt::DType::UI8:
      return DataType::DT_UINT8;
    case tfrt::DType::UI16:
      return DataType::DT_UINT16;
    case tfrt::DType::UI32:
      return DataType::DT_UINT32;
    case tfrt::DType::UI64:
      return DataType::DT_UINT64;
    case tfrt::DType::F16:
      return DataType::DT_HALF;
    case tfrt::DType::BF16:
      return DataType::DT_BFLOAT16;
    case tfrt::DType::F32:
      return DataType::DT_FLOAT;
    case tfrt::DType::F64:
      return DataType::DT_DOUBLE;
    case tfrt::DType::Complex64:
      return DataType::DT_COMPLEX64;
    case tfrt::DType::Complex128:
      return DataType::DT_COMPLEX128;
    case tfrt::DType::String:
      return DataType::DT_STRING;
    case tfrt::DType::Resource:
      return DataType::DT_RESOURCE;
    case tfrt::DType::Variant:
      return DataType::DT_VARIANT;
    case tfrt::DType::QUI8:
      return DataType::DT_QUINT8;
    case tfrt::DType::QUI16:
      return DataType::DT_QUINT16;
    case tfrt::DType::QI8:
      return DataType::DT_QINT8;
    case tfrt::DType::QI16:
      return DataType::DT_QINT16;
    case tfrt::DType::QI32:
      return DataType::DT_QINT32;
    default:
      TFRT_DLOG(ERROR) << "unsupported tfrt::DType "
                       << static_cast<int>(attr_type)
                       << " in TFRT fallback kernel.";
      abort();
  }
}

tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type) {
  switch (data_type) {
    case DataType::DT_UINT8:
      return tfrt::DType::UI8;
    case DataType::DT_UINT16:
      return tfrt::DType::UI16;
    case DataType::DT_UINT32:
      return tfrt::DType::UI32;
    case DataType::DT_UINT64:
      return tfrt::DType::UI64;
    case DataType::DT_BOOL:
      return tfrt::DType::I1;
    case DataType::DT_INT8:
      return tfrt::DType::I8;
    case DataType::DT_INT16:
      return tfrt::DType::I16;
    case DataType::DT_INT32:
      return tfrt::DType::I32;
    case DataType::DT_INT64:
      return tfrt::DType::I64;
    case DataType::DT_HALF:
      return tfrt::DType::F16;
    case DataType::DT_BFLOAT16:
      return tfrt::DType::BF16;
    case DataType::DT_FLOAT:
      return tfrt::DType::F32;
    case DataType::DT_DOUBLE:
      return tfrt::DType::F64;
    case DataType::DT_COMPLEX64:
      return tfrt::DType::Complex64;
    case DataType::DT_COMPLEX128:
      return tfrt::DType::Complex128;
    case DataType::DT_STRING:
      return tfrt::DType::String;
    case DataType::DT_RESOURCE:
      return tfrt::DType::Resource;
    case DataType::DT_VARIANT:
      return tfrt::DType::Variant;
    case DataType::DT_QUINT8:
      return tfrt::DType::QUI8;
    case DataType::DT_QUINT16:
      return tfrt::DType::QUI16;
    case DataType::DT_QINT8:
      return tfrt::DType::QI8;
    case DataType::DT_QINT16:
      return tfrt::DType::QI16;
    case DataType::DT_QINT32:
      return tfrt::DType::QI32;
    default:
      TFRT_DLOG(ERROR) << "unsupported DataType " << static_cast<int>(data_type)
                       << " in TFRT fallback kernel.";
      abort();
  }
}

Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) {
  if (attr_value == "false") {
    *bool_val = false;
    return Status::OK();
  } else if (attr_value == "true") {
    *bool_val = true;
    return Status::OK();
  } else {
    return errors::InvalidArgument("Could not parse bool from \"", attr_value,
                                   "\"");
  }
}

Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val) {
  bool success = absl::SimpleAtoi(attr_value, int_val);
  if (!success) {
    return errors::InvalidArgument("Could not parse int from \"", attr_value,
                                   "\"");
  }
  return Status::OK();
}

Status ParseTensorAttrValue(absl::string_view attr_value,
                            tensorflow::Tensor* tensor) {
  if (std::is_base_of<tensorflow::protobuf::Message,
                      tensorflow::TensorProto>()) {
    tensorflow::TensorProto tensor_proto;
    // We use reinterpret_cast here to make sure the ParseFromString call
    // below compiles if TensorProto is not a subclass of Message.
    // At run time, we should never get to this point if TensorProto
    // is not a subclass of Message, due to the if-condition above.
    auto* message = reinterpret_cast<protobuf::Message*>(&tensor_proto);
    if (protobuf::TextFormat::ParseFromString(
            static_cast<std::string>(attr_value), message) &&
        tensor->FromProto(tensor_proto)) {
      return Status::OK();
    } else {
      return errors::InvalidArgument("Could not parse tensor value from \"",
                                     attr_value, "\"");
    }
  } else {
    // TextFormat does not work with portable proto implementations.
    return errors::InvalidArgument(
        "Tensor attributes are not supported on mobile.");
  }
}

Status ParseTensorShapeAttrValue(absl::string_view attr_value,
                                 std::vector<int64_t>* shape_val) {
  if (attr_value.size() < 2 || attr_value[0] != '[' ||
      attr_value[attr_value.size() - 1] != ']') {
    return errors::InvalidArgument(
        "Tensor shape attribute must be a string of the form [1,2...], instead "
        "got \"",
        attr_value, "\"");
  }
  absl::string_view attr_value_trunc =
      attr_value.substr(1, attr_value.size() - 2);
  // `container` is an absl::strings_internal::Splitter, which is a
  // lazy-splitting iterable. So we cannot get its size to reserve `shape_val`.
  auto container = absl::StrSplit(attr_value_trunc, ',');
  for (auto it = container.begin(); it != container.end(); ++it) {
    int64_t int_val;
    if (!ParseIntAttrValue(*it, &int_val).ok()) {
      return errors::InvalidArgument("Failed to parse an integer value from ",
                                     *it, " while parsing shape.");
    }
    shape_val->push_back(int_val);
  }
  return Status::OK();
}

bool IsUnusedAttribute(absl::string_view attr_name) {
  // These are extra attributes added by TF MLIR dialect, and not needed by
  // current TF runtime.
  //
  // TODO(chky): Consider removing this attribute in tf-to-tfrt
  // lowering.
  return absl::StrContains(attr_name, "result_segment_sizes") ||
         absl::StrContains(attr_name, "operand_segment_sizes") ||
         absl::EndsWith(attr_name, "_tf_data_function");
}

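// Converts every usable entry in `attrs` to a tensorflow AttrValue and
// inserts it into `attr_value_map`, keyed by the attribute name.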
llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
                             tfrt::HostContext* host,
                             tensorflow::AttrValueMap* attr_value_map) {
  AttrValue attr_tmp;
  llvm::Error error = llvm::Error::success();
  attrs.IterateEntries([&error, attr_value_map, &attr_tmp, host,
                        &attrs](const OpAttrsRawEntry& entry) {
    // TFE does not expect a device attribute.
    assert(strcmp(entry.name, "device") != 0);
    if (IsUnusedAttribute(entry.name)) {
      return;
    } else if (entry.IsArray()) {
      error = FillAttrValueMapUsingArray(entry, attr_tmp, attrs);
    } else {
      error = FillAttrValueMapUsingScalar(entry, attr_tmp, host, attrs);
    }
    if (error) return;
    attr_value_map->insert(AttrValueMap::value_type(entry.name, attr_tmp));
  });
  return error;
}

namespace {

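// Copies the payload of a DenseAttr into a newly allocated tensorflow::Tensor
// with the corresponding dtype and shape.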
tensorflow::Tensor CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr) {
  tensorflow::TensorShape shape(
      absl::InlinedVector<int64, 4>(attr.shape().begin(), attr.shape().end()));
  tensorflow::DataType dtype = ConvertBefAttrTypeToTfDataType(attr.dtype());

  tensorflow::Tensor tensor(dtype, shape);

  std::memcpy(tensor.data(), attr.GetElements(), tensor.TotalBytes());

  return tensor;
}

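// Converts a scalar BEF attribute (shape, dense tensor, type, bool, f32, i64,
// or string) into a tensorflow AttrValue.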
Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,
                       tensorflow::AttrValue* tf_attr) {
  if (auto shape_attr = bef_attr.dyn_cast<tfrt::ShapeAttr>()) {
    if (shape_attr.HasRank()) {
      tensorflow::PartialTensorShape tf_shape(shape_attr.GetShape());
      tf_shape.AsProto(tf_attr->mutable_shape());
    } else {
      tensorflow::PartialTensorShape unranked_shape;
      unranked_shape.AsProto(tf_attr->mutable_shape());
    }
  } else if (auto dense_attr = bef_attr.dyn_cast<tfrt::DenseAttr>()) {
    auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
    tf_tensor.AsProtoTensorContent(tf_attr->mutable_tensor());
  } else if (auto type_attr = bef_attr.dyn_cast<tfrt::TypeAttr>()) {
    tf_attr->set_type(ConvertBefAttrTypeToTfDataType(type_attr.GetValue()));
  } else if (auto i1_attr = bef_attr.dyn_cast<tfrt::I1Attr>()) {
    tf_attr->set_b(i1_attr.GetValue());
  } else if (auto f32_attr = bef_attr.dyn_cast<tfrt::F32Attr>()) {
    tf_attr->set_f(f32_attr.GetValue());
  } else if (auto i64_attr = bef_attr.dyn_cast<tfrt::I64Attr>()) {
    tf_attr->set_i(i64_attr.GetValue());
  } else if (auto string_attr = bef_attr.dyn_cast<tfrt::StringAttr>()) {
    tf_attr->set_s(string_attr.GetValue().data(),
                   string_attr.GetValue().size());
  } else {
    return tensorflow::errors::Internal("Failed to set up attribute.");
  }

  return Status::OK();
}

Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,
                               tensorflow::AttrValue& tf_attr) {
  tfrt::string_view func_name = func_attr.GetValue();
  tf_attr.mutable_func()->set_name(func_name.data(), func_name.size());
  return Status::OK();
}

void AddShapeToAttrList(tfrt::ShapeAttr shape,
                        tensorflow::AttrValue::ListValue* list) {
  if (shape.HasRank()) {
    tensorflow::PartialTensorShape tf_shape(shape.GetShape());
    tf_shape.AsProto(list->add_shape());
    return;
  }

  tensorflow::PartialTensorShape unranked_shape;
  unranked_shape.AsProto(list->add_shape());
}

void AddTensorToAttrList(tfrt::DenseAttr dense_attr,
                         tensorflow::AttrValue::ListValue* list) {
  auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
  tf_tensor.AsProtoTensorContent(list->add_tensor());
}

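// Converts an AggregateAttr of shapes, dense tensors, or strings into a
// tensorflow AttrValue list.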
Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr,
                     tensorflow::AttrValue* tf_attr) {
  auto* list = tf_attr->mutable_list();
  for (int i = 0; i < aggregate_attr.GetNumElements(); ++i) {
    auto base = aggregate_attr.GetAttribute(i);
    if (auto shape_attr = base.dyn_cast<tfrt::ShapeAttr>()) {
      AddShapeToAttrList(shape_attr, list);
    } else if (auto dense_attr = base.dyn_cast<tfrt::DenseAttr>()) {
      AddTensorToAttrList(dense_attr, list);
    } else if (auto string_attr = base.dyn_cast<tfrt::StringAttr>()) {
      list->add_s(string_attr.GetValue().data(), string_attr.GetValue().size());
    } else {
      return tensorflow::errors::Internal("Failed to set up list attr.");
    }
  }
  return Status::OK();
}

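// Converts an ArrayAttr of bool, i64, f32, or dtype elements into a
// tensorflow AttrValue list.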
Status SetUpListAttr(tfrt::ArrayAttr array_attr,
                     tensorflow::AttrValue* tf_attr) {
  auto* list = tf_attr->mutable_list();

  // Handle an empty array case.
  if (array_attr.GetNumElements() == 0) {
    return Status::OK();
  }

  tfrt::BEFAttributeType element_type = array_attr.GetElementType();
  if (tfrt::IsDataTypeAttribute(element_type)) {
    tfrt::DType dtype = GetDataType(element_type);
    switch (dtype) {
      case tfrt::DType::I1: {
        for (auto value : array_attr.GetValue<bool>()) {
          list->add_b(value);
        }
        return Status::OK();
      }
      case tfrt::DType::I64: {
        for (auto value : array_attr.GetValue<int64_t>()) {
          list->add_i(value);
        }
        return Status::OK();
      }
      case tfrt::DType::F32: {
        for (auto value : array_attr.GetValue<float>()) {
          list->add_f(value);
        }
        return Status::OK();
      }
      default:
        return tensorflow::errors::Internal(
            StrCat("Failed to set up list attr: unsupported dtype: ",
                   tfrt::DType(dtype)));
    }
  } else if (element_type == tfrt::BEFAttributeType::kType) {
    for (auto value : array_attr.GetValue<tfrt::DType>()) {
      list->add_type(ConvertBefAttrTypeToTfDataType(value));
    }
    return Status::OK();
  }

  return tensorflow::errors::Internal("Failed to set up list attr.");
}

}  // namespace

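// Builds `attr_value_map` from the BEF-encoded op attribute and op function
// attribute arrays, where each element is a (name, attribute) pair.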
Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,
                         tfrt::AggregateAttr op_func_attr_array,
                         tensorflow::AttrValueMap* attr_value_map) {
  auto obtain_name_attr_pair =
      [](tfrt::AggregateAttr attr_array,
         int i) -> std::pair<std::string, tfrt::TypedAttrBase> {
    auto pair = attr_array.GetAttributeOfType<tfrt::AggregateAttr>(i);
    assert(pair.GetNumElements() == 2);
    return {pair.GetAttributeOfType<tfrt::StringAttr>(0).GetValue().str(),
            pair.GetAttribute(1)};
  };

  for (size_t i = 0, e = op_attr_array.GetNumElements(); i != e; ++i) {
    auto name_attr_pair = obtain_name_attr_pair(op_attr_array, i);
    if (IsUnusedAttribute(name_attr_pair.first)) continue;

    AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
    tfrt::TypedAttrBase attr_value = name_attr_pair.second;
    if (auto aggregate_attr = attr_value.dyn_cast<tfrt::AggregateAttr>()) {
      TF_RETURN_IF_ERROR(SetUpListAttr(aggregate_attr, &tf_attr));
    } else if (auto array_attr = attr_value.dyn_cast<tfrt::ArrayAttr>()) {
      TF_RETURN_IF_ERROR(SetUpListAttr(array_attr, &tf_attr));
    } else {
      TF_RETURN_IF_ERROR(SetUpScalarAttr(attr_value, &tf_attr));
    }
  }

  for (size_t i = 0, e = op_func_attr_array.GetNumElements(); i != e; ++i) {
    auto name_attr_pair = obtain_name_attr_pair(op_func_attr_array, i);
    if (IsUnusedAttribute(name_attr_pair.first)) continue;

    AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
    auto attr_value = name_attr_pair.second.dyn_cast<tfrt::StringAttr>();
    TF_RETURN_IF_ERROR(SetUpScalarFunctionAttr(attr_value, tf_attr));
  }

  return Status::OK();
}

}  // namespace tfd
}  // namespace tensorflow