1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/runtime_fallback/util/attr_util.h"

#include <cstdlib>
#include <cstring>
#include <utility>

#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/logging.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
38
39 namespace tensorflow {
40 namespace tfd {
41 namespace {
42
43 using ::tensorflow::protobuf::RepeatedFieldBackInserter;
44 using ::tfrt::AggregateAttr;
45 using ::tfrt::BEFAttributeType;
46 using ::tfrt::DenseAttr;
47 using ::tfrt::DenseHostTensor;
48 using ::tfrt::HostContext;
49 using ::tfrt::OpAttrsRawEntry;
50 using ::tfrt::OpAttrsRef;
51 using ::tfrt::OpAttrType;
52 using ::tfrt::SmallVector;
53 using ::tfrt::string_view;
54
DecodeDenseAttrToTfTensor(const DenseAttr & dense_attr,HostContext * host)55 llvm::Expected<tensorflow::Tensor> DecodeDenseAttrToTfTensor(
56 const DenseAttr& dense_attr, HostContext* host) {
57 llvm::Expected<DenseHostTensor> dht =
58 tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
59 if (!dht) {
60 return tfrt::MakeStringError(
61 "Cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface: ",
62 dht.takeError());
63 }
64
65 return tfrt::TFRTTensorToTFTensor(*dht, host);
66 }
67
FillAttrValueMapUsingArray(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)68 llvm::Error FillAttrValueMapUsingArray(const OpAttrsRawEntry& entry,
69 AttrValue& attr_tmp,
70 const OpAttrsRef& attrs) {
71 attr_tmp.mutable_list()->Clear();
72 if (entry.element_count == 0) {
73 if (entry.type == OpAttrType::CHAR) {
74 // Empty string.
75 attr_tmp.set_s("");
76 }
77 // Empty array of other types.
78 return llvm::Error::success();
79 }
80 switch (entry.type) {
81 case OpAttrType::CHAR: {
82 string_view attr_value = attrs.GetStringAsserting(entry.name);
83 attr_tmp.set_s(attr_value.data(), attr_value.size());
84 return llvm::Error::success();
85 }
86
87 case OpAttrType::FUNC: {
88 string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
89 attr_tmp.mutable_func()->set_name(attr_value.data(), attr_value.size());
90 return llvm::Error::success();
91 }
92 case OpAttrType::I64: {
93 llvm::ArrayRef<int64_t> int_array =
94 attrs.GetArrayAsserting<int64_t>(entry.name);
95 auto* mutable_i = attr_tmp.mutable_list()->mutable_i();
96 std::copy(int_array.begin(), int_array.end(),
97 RepeatedFieldBackInserter(mutable_i));
98 return llvm::Error::success();
99 }
100 case OpAttrType::F32: {
101 llvm::ArrayRef<float> float_array =
102 attrs.GetArrayAsserting<float>(entry.name);
103 auto* mutable_f = attr_tmp.mutable_list()->mutable_f();
104 std::copy(float_array.begin(), float_array.end(),
105 RepeatedFieldBackInserter(mutable_f));
106 return llvm::Error::success();
107 }
108 case OpAttrType::BOOL: {
109 llvm::ArrayRef<bool> bool_array =
110 attrs.GetArrayAsserting<bool>(entry.name);
111 auto mutable_b = attr_tmp.mutable_list()->mutable_b();
112 std::copy(bool_array.begin(), bool_array.end(),
113 RepeatedFieldBackInserter(mutable_b));
114 return llvm::Error::success();
115 }
116 case OpAttrType::DTYPE: {
117 const auto& op_attr = attrs.GetRawAsserting(entry.name);
118 assert(op_attr.IsArray());
119
120 // DTypes in BEF attributes are tfrt::DType enums. So we need
121 // to convert then to tensorflow data types first.
122 auto bef_dtypes =
123 llvm::makeArrayRef(static_cast<const tfrt::DType*>(op_attr.GetData()),
124 op_attr.element_count);
125
126 SmallVector<tensorflow::DataType, 4> tf_dtypes;
127 tf_dtypes.reserve(bef_dtypes.size());
128 for (auto bef_dtype : bef_dtypes) {
129 tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
130 }
131 auto* mutable_type = attr_tmp.mutable_list()->mutable_type();
132 std::copy(tf_dtypes.begin(), tf_dtypes.end(),
133 RepeatedFieldBackInserter(mutable_type));
134 return llvm::Error::success();
135 }
136 default:
137 return tfrt::MakeStringError("unsupported array attribute type");
138 }
139 }
140
FillAttrValueMapUsingAggregate(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,const OpAttrsRef & attrs)141 llvm::Error FillAttrValueMapUsingAggregate(const OpAttrsRawEntry& entry,
142 AttrValue& attr_tmp,
143 const OpAttrsRef& attrs) {
144 AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
145 int num_values = list_attr.GetNumElements();
146 if (num_values == 0) {
147 // Create an empty list.
148 attr_tmp.mutable_list();
149 return llvm::Error::success();
150 }
151 // It is guaranteed that items in one list attribute have the same
152 // type, though their sizes can be different. In particular,
153 // list(TensorShape) and list(Tensor) attribute types have to be
154 // encoded as AggregateAttr.
155 auto attr_base = list_attr.GetAttribute(0);
156 auto* mutable_list = attr_tmp.mutable_list();
157 mutable_list->Clear();
158 if (IsDataTypeAttribute(attr_base.type()) &&
159 GetDataType(attr_base.type()) == tfrt::DType::String) {
160 // Handle list(string).
161 auto* mutable_s = mutable_list->mutable_s();
162 mutable_s->Reserve(num_values);
163 for (int i = 0; i < num_values; ++i) {
164 auto string_attr = list_attr.GetAttributeOfType<tfrt::StringAttr>(i);
165 mutable_list->add_s(string_attr.GetValue().data(),
166 string_attr.GetValue().size());
167 }
168 } else if (attr_base.type() == BEFAttributeType::kFunc) {
169 // Handle list(Function).
170 auto* mutable_f = mutable_list->mutable_func();
171 mutable_f->Reserve(num_values);
172 for (int i = 0; i < num_values; ++i) {
173 auto func_attr = list_attr.GetAttributeOfType<tfrt::FuncAttr>(i);
174 auto mutable_func = mutable_list->add_func();
175 mutable_func->set_name(func_attr.GetFunctionName().str());
176 }
177 } else if (attr_base.type() == BEFAttributeType::kShape) {
178 // Handle list(TensorShape).
179 auto* mutable_list = attr_tmp.mutable_list();
180 auto* mutable_shape = mutable_list->mutable_shape();
181 mutable_shape->Reserve(num_values);
182 for (int i = 0; i < num_values; ++i) {
183 auto shape_attr = list_attr.GetAttributeOfType<tfrt::ShapeAttr>(i);
184 auto* added_shape = mutable_list->add_shape();
185 if (shape_attr.HasRank()) {
186 int rank = shape_attr.GetRank();
187 auto shape = shape_attr.GetShape();
188 added_shape->mutable_dim()->Reserve(rank);
189 for (int d = 0; d < rank; ++d) {
190 added_shape->add_dim()->set_size(shape[d]);
191 }
192 } else {
193 added_shape->set_unknown_rank(true);
194 }
195 }
196 } else {
197 return tfrt::MakeStringError("unsupported list attribute type");
198 }
199 return llvm::Error::success();
200 }
201
FillAttrValueMapUsingScalar(const OpAttrsRawEntry & entry,AttrValue & attr_tmp,HostContext * host,const OpAttrsRef & attrs)202 llvm::Error FillAttrValueMapUsingScalar(const OpAttrsRawEntry& entry,
203 AttrValue& attr_tmp, HostContext* host,
204 const OpAttrsRef& attrs) {
205 switch (entry.type) {
206 case OpAttrType::I64: {
207 int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
208 attr_tmp.set_i(attr_value);
209 return llvm::Error::success();
210 }
211 case OpAttrType::F32: {
212 float attr_value = attrs.GetAsserting<float>(entry.name);
213 attr_tmp.set_f(attr_value);
214 return llvm::Error::success();
215 }
216 case OpAttrType::BOOL: {
217 bool attr_value = attrs.GetAsserting<bool>(entry.name);
218 attr_tmp.set_b(attr_value);
219 return llvm::Error::success();
220 }
221 case OpAttrType::DTYPE: {
222 OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
223 DataType tf_dtype = ConvertToTfDataType(op_attr_type);
224 attr_tmp.set_type(tf_dtype);
225 return llvm::Error::success();
226 }
227 case OpAttrType::SHAPE: {
228 auto shape_attr = attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
229 auto* mutable_shape = attr_tmp.mutable_shape();
230 if (shape_attr.HasRank()) {
231 int rank = shape_attr.GetRank();
232 auto shape = shape_attr.GetShape();
233 mutable_shape->mutable_dim()->Reserve(rank);
234 for (int d = 0; d < rank; ++d) {
235 mutable_shape->add_dim()->set_size(shape[d]);
236 }
237 } else {
238 mutable_shape->set_unknown_rank(true);
239 }
240 return llvm::Error::success();
241 }
242 case OpAttrType::DENSE: {
243 auto dense_attr = attrs.GetAsserting<tfrt::DenseAttr>(entry.name);
244 llvm::Expected<tensorflow::Tensor> tf_tensor =
245 DecodeDenseAttrToTfTensor(dense_attr, host);
246 if (!tf_tensor) return tf_tensor.takeError();
247 auto* mutable_tensor = attr_tmp.mutable_tensor();
248 if (tf_tensor->NumElements() > 1) {
249 tf_tensor->AsProtoTensorContent(mutable_tensor);
250 } else {
251 tf_tensor->AsProtoField(mutable_tensor);
252 }
253 return llvm::Error::success();
254 }
255 case OpAttrType::AGGREGATE: {
256 return FillAttrValueMapUsingAggregate(entry, attr_tmp, attrs);
257 }
258 default:
259 LOG(ERROR) << "failure case";
260 return tfrt::MakeStringError("unsupported scalar attribute type");
261 }
262 }
263
264 } // namespace
265
ParseTfDataType(absl::string_view dtype,DataType * data_type)266 Status ParseTfDataType(absl::string_view dtype, DataType* data_type) {
267 if (dtype == "DT_INT8") {
268 *data_type = DataType::DT_INT8;
269 return Status::OK();
270 } else if (dtype == "DT_INT32") {
271 *data_type = DataType::DT_INT32;
272 return Status::OK();
273 } else if (dtype == "DT_INT64") {
274 *data_type = DataType::DT_INT64;
275 return Status::OK();
276 } else if (dtype == "DT_HALF") {
277 *data_type = DataType::DT_HALF;
278 return Status::OK();
279 } else if (dtype == "DT_FLOAT") {
280 *data_type = DataType::DT_FLOAT;
281 return Status::OK();
282 } else if (dtype == "DT_DOUBLE") {
283 *data_type = DataType::DT_DOUBLE;
284 return Status::OK();
285 } else {
286 return errors::InvalidArgument("Unsupported dtype, ", std::string(dtype),
287 " in ParseTfDataType.");
288 }
289 }
290
ConvertToTfDataType(tfrt::OpAttrType op_attr_type)291 DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type) {
292 switch (op_attr_type) {
293 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
294 case tfrt::OpAttrType::TFRT_ENUM: \
295 return DataType::DT_ENUM;
296 #include "tensorflow/core/runtime_fallback/util/attr_type.def" // NOLINT
297 default:
298 TFRT_DLOG(ERROR) << "unsupported dtype" << static_cast<int>(op_attr_type)
299 << " in TFRT fallback kernel.";
300 abort();
301 }
302 }
303
ConvertFromTfDataType(DataType data_type)304 tfrt::OpAttrType ConvertFromTfDataType(DataType data_type) {
305 switch (data_type) {
306 #define OP_ATTR_TYPE(TFRT_ENUM, DT_ENUM) \
307 case DataType::DT_ENUM: \
308 return tfrt::OpAttrType::TFRT_ENUM;
309 #include "tensorflow/core/runtime_fallback/util/attr_type.def" // NOLINT
310 default:
311 TFRT_DLOG(ERROR) << "unsupported dtype " << static_cast<int>(data_type)
312 << "in TFRT fallback kernel.";
313 abort();
314 }
315 }
316
ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type)317 DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type) {
318 switch (attr_type) {
319 case tfrt::DType::I1:
320 return DataType::DT_BOOL;
321 case tfrt::DType::I8:
322 return DataType::DT_INT8;
323 case tfrt::DType::I16:
324 return DataType::DT_INT16;
325 case tfrt::DType::I32:
326 return DataType::DT_INT32;
327 case tfrt::DType::I64:
328 return DataType::DT_INT64;
329 case tfrt::DType::UI8:
330 return DataType::DT_UINT8;
331 case tfrt::DType::UI16:
332 return DataType::DT_UINT16;
333 case tfrt::DType::UI32:
334 return DataType::DT_UINT32;
335 case tfrt::DType::UI64:
336 return DataType::DT_UINT64;
337 case tfrt::DType::F16:
338 return DataType::DT_HALF;
339 case tfrt::DType::BF16:
340 return DataType::DT_BFLOAT16;
341 case tfrt::DType::F32:
342 return DataType::DT_FLOAT;
343 case tfrt::DType::F64:
344 return DataType::DT_DOUBLE;
345 case tfrt::DType::Complex64:
346 return DataType::DT_COMPLEX64;
347 case tfrt::DType::Complex128:
348 return DataType::DT_COMPLEX128;
349 case tfrt::DType::String:
350 return DataType::DT_STRING;
351 case tfrt::DType::Resource:
352 return DataType::DT_RESOURCE;
353 case tfrt::DType::Variant:
354 return DataType::DT_VARIANT;
355 case tfrt::DType::QUI8:
356 return DataType::DT_QUINT8;
357 case tfrt::DType::QUI16:
358 return DataType::DT_QUINT16;
359 case tfrt::DType::QI8:
360 return DataType::DT_QINT8;
361 case tfrt::DType::QI16:
362 return DataType::DT_QINT16;
363 case tfrt::DType::QI32:
364 return DataType::DT_QINT32;
365 default:
366 TFRT_DLOG(ERROR) << "unsupported tfrt::DType"
367 << static_cast<int>(attr_type)
368 << " in TFRT fallback kernel.";
369 abort();
370 }
371 }
372
ConvertTfDataTypeToBefAttrType(DataType data_type)373 tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type) {
374 switch (data_type) {
375 case DataType::DT_UINT8:
376 return tfrt::DType::UI8;
377 case DataType::DT_UINT16:
378 return tfrt::DType::UI16;
379 case DataType::DT_UINT32:
380 return tfrt::DType::UI32;
381 case DataType::DT_UINT64:
382 return tfrt::DType::UI64;
383 case DataType::DT_BOOL:
384 return tfrt::DType::I1;
385 case DataType::DT_INT8:
386 return tfrt::DType::I8;
387 case DataType::DT_INT16:
388 return tfrt::DType::I16;
389 case DataType::DT_INT32:
390 return tfrt::DType::I32;
391 case DataType::DT_INT64:
392 return tfrt::DType::I64;
393 case DataType::DT_HALF:
394 return tfrt::DType::F16;
395 case DataType::DT_BFLOAT16:
396 return tfrt::DType::BF16;
397 case DataType::DT_FLOAT:
398 return tfrt::DType::F32;
399 case DataType::DT_DOUBLE:
400 return tfrt::DType::F64;
401 case DataType::DT_COMPLEX64:
402 return tfrt::DType::Complex64;
403 case DataType::DT_COMPLEX128:
404 return tfrt::DType::Complex128;
405 case DataType::DT_STRING:
406 return tfrt::DType::String;
407 case DataType::DT_RESOURCE:
408 return tfrt::DType::Resource;
409 case DataType::DT_VARIANT:
410 return tfrt::DType::Variant;
411 case DataType::DT_QUINT8:
412 return tfrt::DType::QUI8;
413 case DataType::DT_QUINT16:
414 return tfrt::DType::QUI16;
415 case DataType::DT_QINT8:
416 return tfrt::DType::QI8;
417 case DataType::DT_QINT16:
418 return tfrt::DType::QI16;
419 case DataType::DT_QINT32:
420 return tfrt::DType::QI32;
421 default:
422 TFRT_DLOG(ERROR) << "unsupported DataType " << static_cast<int>(data_type)
423 << " in TFRT fallback kernel.";
424 abort();
425 }
426 }
427
ParseBoolAttrValue(absl::string_view attr_value,bool * bool_val)428 Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val) {
429 if (attr_value == "false") {
430 *bool_val = false;
431 return Status::OK();
432 } else if (attr_value == "true") {
433 *bool_val = true;
434 return Status::OK();
435 } else {
436 return errors::InvalidArgument("Could not parse bool from \"", attr_value,
437 "\"");
438 }
439 }
440
ParseIntAttrValue(absl::string_view attr_value,int64_t * int_val)441 Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val) {
442 bool success = absl::SimpleAtoi(attr_value, int_val);
443 if (!success) {
444 return errors::InvalidArgument("Could not parse int from \"", attr_value,
445 "\"");
446 }
447 return Status::OK();
448 }
449
ParseTensorAttrValue(absl::string_view attr_value,tensorflow::Tensor * tensor)450 Status ParseTensorAttrValue(absl::string_view attr_value,
451 tensorflow::Tensor* tensor) {
452 if (std::is_base_of<tensorflow::protobuf::Message,
453 tensorflow::TensorProto>()) {
454 tensorflow::TensorProto tensor_proto;
455 // We use reinterpret_cast here to make sure ParseFromString call
456 // below compiles if TensorProto is not a subclass of Message.
457 // At run time, we should never get to this point if TensorProto
458 // is not a subclass of message due to if-condition above.
459 auto* message = reinterpret_cast<protobuf::Message*>(&tensor_proto);
460 if (protobuf::TextFormat::ParseFromString(
461 static_cast<std::string>(attr_value), message) &&
462 tensor->FromProto(tensor_proto)) {
463 return Status::OK();
464 } else {
465 return errors::InvalidArgument("Could not parse tensor value from \"",
466 attr_value, "\"");
467 }
468 } else {
469 // TextFormat does not work with portable proto implementations.
470 return errors::InvalidArgument(
471 "Tensor attributes are not supported on mobile.");
472 }
473 }
474
ParseTensorShapeAttrValue(absl::string_view attr_value,std::vector<int64_t> * shape_val)475 Status ParseTensorShapeAttrValue(absl::string_view attr_value,
476 std::vector<int64_t>* shape_val) {
477 if (attr_value.size() < 2 || attr_value[0] != '[' ||
478 attr_value[attr_value.size() - 1] != ']') {
479 return errors::InvalidArgument(
480 "Tensor shape attribute must be a string of the form [1,2...], instead "
481 "got \"",
482 attr_value, "\"");
483 }
484 absl::string_view attr_value_trunc =
485 attr_value.substr(1, attr_value.size() - 2);
486 // `container` is an absl::strings_internal::Splitter, which is a
487 // lazy-splitting iterable. So we cannot get its size to reserve `dims`.
488 auto container = absl::StrSplit(attr_value_trunc, ',');
489 for (auto it = container.begin(); it != container.end(); ++it) {
490 int64_t int_val;
491 if (!ParseIntAttrValue(*it, &int_val).ok()) {
492 return errors::InvalidArgument("Failed to parse an integer value from ",
493 *it, " while parsing shape.");
494 }
495 shape_val->push_back(int_val);
496 }
497 return Status::OK();
498 }
499
IsUnusedAttribute(absl::string_view attr_name)500 bool IsUnusedAttribute(absl::string_view attr_name) {
501 // These are extra attributes added by TF MLIR dialect, and not needed by
502 // current TF runtime.
503 //
504 // TODO(chky): Consider removing this attribute in tf-to-tfrt
505 // lowering.
506 return absl::StrContains(attr_name, "result_segment_sizes") ||
507 absl::StrContains(attr_name, "operand_segment_sizes") ||
508 absl::EndsWith(attr_name, "_tf_data_function");
509 }
510
FillAttrValueMap(const tfrt::OpAttrsRef & attrs,tfrt::HostContext * host,tensorflow::AttrValueMap * attr_value_map)511 llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
512 tfrt::HostContext* host,
513 tensorflow::AttrValueMap* attr_value_map) {
514 AttrValue attr_tmp;
515 llvm::Error error = llvm::Error::success();
516 attrs.IterateEntries([&error, attr_value_map, &attr_tmp, host,
517 &attrs](const OpAttrsRawEntry& entry) {
518 // TFE does not expect a device attribute.
519 assert(strcmp(entry.name, "device") != 0);
520 if (IsUnusedAttribute(entry.name)) {
521 return;
522 } else if (entry.IsArray()) {
523 error = FillAttrValueMapUsingArray(entry, attr_tmp, attrs);
524 } else {
525 error = FillAttrValueMapUsingScalar(entry, attr_tmp, host, attrs);
526 }
527 if (error) return;
528 attr_value_map->insert(AttrValueMap::value_type(entry.name, attr_tmp));
529 });
530 return error;
531 }
532
533 namespace {
534
CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr)535 tensorflow::Tensor CreateTfTensorFromDenseAttr(tfrt::DenseAttr attr) {
536 tensorflow::TensorShape shape(
537 absl::InlinedVector<int64, 4>(attr.shape().begin(), attr.shape().end()));
538 tensorflow::DataType dtype = ConvertBefAttrTypeToTfDataType(attr.dtype());
539
540 tensorflow::Tensor tensor(dtype, shape);
541
542 std::memcpy(tensor.data(), attr.GetElements(), tensor.TotalBytes());
543
544 return tensor;
545 }
546
SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,tensorflow::AttrValue * tf_attr)547 Status SetUpScalarAttr(tfrt::TypedAttrBase bef_attr,
548 tensorflow::AttrValue* tf_attr) {
549 if (auto shape_attr = bef_attr.dyn_cast<tfrt::ShapeAttr>()) {
550 if (shape_attr.HasRank()) {
551 tensorflow::PartialTensorShape tf_shape(shape_attr.GetShape());
552 tf_shape.AsProto(tf_attr->mutable_shape());
553 } else {
554 tensorflow::PartialTensorShape unranked_shape;
555 unranked_shape.AsProto(tf_attr->mutable_shape());
556 }
557 } else if (auto dense_attr = bef_attr.dyn_cast<tfrt::DenseAttr>()) {
558 auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
559 tf_tensor.AsProtoTensorContent(tf_attr->mutable_tensor());
560 } else if (auto type_attr = bef_attr.dyn_cast<tfrt::TypeAttr>()) {
561 tf_attr->set_type(ConvertBefAttrTypeToTfDataType(type_attr.GetValue()));
562 } else if (auto i1_attr = bef_attr.dyn_cast<tfrt::I1Attr>()) {
563 tf_attr->set_b(i1_attr.GetValue());
564 } else if (auto f32_attr = bef_attr.dyn_cast<tfrt::F32Attr>()) {
565 tf_attr->set_f(f32_attr.GetValue());
566 } else if (auto i64_attr = bef_attr.dyn_cast<tfrt::I64Attr>()) {
567 tf_attr->set_i(i64_attr.GetValue());
568 } else if (auto string_attr = bef_attr.dyn_cast<tfrt::StringAttr>()) {
569 tf_attr->set_s(string_attr.GetValue().data(),
570 string_attr.GetValue().size());
571 } else {
572 return tensorflow::errors::Internal("Failed to set up attribute.");
573 }
574
575 return Status::OK();
576 }
577
SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,tensorflow::AttrValue & tf_attr)578 Status SetUpScalarFunctionAttr(tfrt::StringAttr func_attr,
579 tensorflow::AttrValue& tf_attr) {
580 tfrt::string_view func_name = func_attr.GetValue();
581 tf_attr.mutable_func()->set_name(func_name.data(), func_name.size());
582 return Status::OK();
583 }
584
AddShapeToAttrList(tfrt::ShapeAttr shape,tensorflow::AttrValue::ListValue * list)585 void AddShapeToAttrList(tfrt::ShapeAttr shape,
586 tensorflow::AttrValue::ListValue* list) {
587 if (shape.HasRank()) {
588 tensorflow::PartialTensorShape tf_shape(shape.GetShape());
589 tf_shape.AsProto(list->add_shape());
590 return;
591 }
592
593 tensorflow::PartialTensorShape unranked_shape;
594 unranked_shape.AsProto(list->add_shape());
595 }
AddTensorToAttrList(tfrt::DenseAttr dense_attr,tensorflow::AttrValue::ListValue * list)596 void AddTensorToAttrList(tfrt::DenseAttr dense_attr,
597 tensorflow::AttrValue::ListValue* list) {
598 auto tf_tensor = CreateTfTensorFromDenseAttr(dense_attr);
599 tf_tensor.AsProtoTensorContent(list->add_tensor());
600 }
601
SetUpListAttr(tfrt::AggregateAttr aggregate_attr,tensorflow::AttrValue * tf_attr)602 Status SetUpListAttr(tfrt::AggregateAttr aggregate_attr,
603 tensorflow::AttrValue* tf_attr) {
604 auto* list = tf_attr->mutable_list();
605 for (int i = 0; i < aggregate_attr.GetNumElements(); ++i) {
606 auto base = aggregate_attr.GetAttribute(i);
607 if (auto shape_attr = base.dyn_cast<tfrt::ShapeAttr>()) {
608 AddShapeToAttrList(shape_attr, list);
609 } else if (auto dense_attr = base.dyn_cast<tfrt::DenseAttr>()) {
610 AddTensorToAttrList(dense_attr, list);
611 } else if (auto string_attr = base.dyn_cast<tfrt::StringAttr>()) {
612 list->add_s(string_attr.GetValue().data(), string_attr.GetValue().size());
613 } else {
614 return tensorflow::errors::Internal("Failed to set up list attr.");
615 }
616 }
617 return Status::OK();
618 }
619
SetUpListAttr(tfrt::ArrayAttr array_attr,tensorflow::AttrValue * tf_attr)620 Status SetUpListAttr(tfrt::ArrayAttr array_attr,
621 tensorflow::AttrValue* tf_attr) {
622 auto* list = tf_attr->mutable_list();
623
624 // Handle an empty array case.
625 if (array_attr.GetNumElements() == 0) {
626 return Status::OK();
627 }
628
629 tfrt::BEFAttributeType element_type = array_attr.GetElementType();
630 if (tfrt::IsDataTypeAttribute(element_type)) {
631 tfrt::DType dtype = GetDataType(element_type);
632 switch (dtype) {
633 case tfrt::DType::I1: {
634 for (auto value : array_attr.GetValue<bool>()) {
635 list->add_b(value);
636 }
637 return Status::OK();
638 }
639 case tfrt::DType::I64: {
640 for (auto value : array_attr.GetValue<int64_t>()) {
641 list->add_i(value);
642 }
643 return Status::OK();
644 }
645 case tfrt::DType::F32: {
646 for (auto value : array_attr.GetValue<float>()) {
647 list->add_f(value);
648 }
649 return Status::OK();
650 }
651 default:
652 return tensorflow::errors::Internal(
653 StrCat("Failed to set up list attr: unsupported dtype: ",
654 tfrt::DType(dtype)));
655 }
656 } else if (element_type == tfrt::BEFAttributeType::kType) {
657 for (auto value : array_attr.GetValue<tfrt::DType>()) {
658 list->add_type(ConvertBefAttrTypeToTfDataType(value));
659 }
660 return Status::OK();
661 }
662
663 return tensorflow::errors::Internal("Failed to set up list attr.");
664 }
665
666 } // namespace
667
SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,tfrt::AggregateAttr op_func_attr_array,tensorflow::AttrValueMap * attr_value_map)668 Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array,
669 tfrt::AggregateAttr op_func_attr_array,
670 tensorflow::AttrValueMap* attr_value_map) {
671 auto obtain_name_attr_pair =
672 [](tfrt::AggregateAttr attr_array,
673 int i) -> std::pair<std::string, tfrt::TypedAttrBase> {
674 auto pair = attr_array.GetAttributeOfType<tfrt::AggregateAttr>(i);
675 assert(pair.GetNumElements() == 2);
676 return {pair.GetAttributeOfType<tfrt::StringAttr>(0).GetValue().str(),
677 pair.GetAttribute(1)};
678 };
679
680 for (size_t i = 0, e = op_attr_array.GetNumElements(); i != e; ++i) {
681 auto name_attr_pair = obtain_name_attr_pair(op_attr_array, i);
682 if (IsUnusedAttribute(name_attr_pair.first)) continue;
683
684 AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
685 tfrt::TypedAttrBase attr_value = name_attr_pair.second;
686 if (auto aggregate_attr = attr_value.dyn_cast<tfrt::AggregateAttr>()) {
687 TF_RETURN_IF_ERROR(SetUpListAttr(aggregate_attr, &tf_attr));
688 } else if (auto array_attr = attr_value.dyn_cast<tfrt::ArrayAttr>()) {
689 TF_RETURN_IF_ERROR(SetUpListAttr(array_attr, &tf_attr));
690 } else {
691 TF_RETURN_IF_ERROR(SetUpScalarAttr(attr_value, &tf_attr));
692 }
693 }
694
695 for (size_t i = 0, e = op_func_attr_array.GetNumElements(); i != e; ++i) {
696 auto name_attr_pair = obtain_name_attr_pair(op_func_attr_array, i);
697 if (IsUnusedAttribute(name_attr_pair.first)) continue;
698
699 AttrValue& tf_attr = (*attr_value_map)[name_attr_pair.first];
700 auto attr_value = name_attr_pair.second.dyn_cast<tfrt::StringAttr>();
701 TF_RETURN_IF_ERROR(SetUpScalarFunctionAttr(attr_value, tf_attr));
702 }
703
704 return Status::OK();
705 }
706
707 } // namespace tfd
708 } // namespace tensorflow
709