• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
17 
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
21 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
22 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/platform/errors.h"
25 
26 namespace tensorflow {
27 
28 // Converts non func AttrValue proto into an MLIR attribute. Func attribute is
29 // exclused in this function because the function might be renamed when the
30 // function definition is imported.
ConvertNonFuncAttributeValue(const AttrValue & value,mlir::Builder * builder)31 StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
32                                                        mlir::Builder* builder) {
33   switch (value.value_case()) {
34     case AttrValue::kI:
35       return builder->getI64IntegerAttr(value.i());
36     case AttrValue::kS:
37       return builder->getStringAttr(value.s());
38     case AttrValue::kF:
39       return builder->getFloatAttr(builder->getF32Type(), value.f());
40     case AttrValue::kB:
41       return builder->getBoolAttr(value.b());
42     case AttrValue::kType: {
43       mlir::Type type;
44       TF_RETURN_IF_ERROR(ConvertDataType(value.type(), *builder, &type));
45       return mlir::TypeAttr::get(type);
46     }
47     case AttrValue::kShape:
48       return ConvertTensorShapeProto(value.shape(), builder->getContext());
49     case AttrValue::kTensor:
50       return ConvertTensorProto(value.tensor(), builder);
51     case AttrValue::kList: {
52       absl::InlinedVector<mlir::Attribute, 8> attrs;
53       for (const auto& item : value.list().i())
54         attrs.push_back(builder->getI64IntegerAttr(item));
55       for (const auto& item : value.list().s())
56         attrs.push_back(builder->getStringAttr(item));
57       for (const auto& item : value.list().f())
58         attrs.push_back(builder->getFloatAttr(builder->getF32Type(), item));
59       for (const auto& item : value.list().b())
60         attrs.push_back(builder->getBoolAttr(item));
61       for (const auto& item : value.list().type()) {
62         mlir::Type type;
63         TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), *builder, &type));
64         attrs.push_back(mlir::TypeAttr::get(type));
65       }
66       for (const auto& item : value.list().shape()) {
67         TF_ASSIGN_OR_RETURN(
68             auto attr, ConvertTensorShapeProto(item, builder->getContext()));
69         attrs.push_back(attr);
70       }
71       for (const auto& item : value.list().tensor()) {
72         TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder));
73         attrs.push_back(attr);
74       }
75       if (!value.list().func().empty()) {
76         return tensorflow::errors::Unimplemented(
77             absl::StrCat("Attribute ", value.DebugString()));
78       }
79       return builder->getArrayAttr(
80           llvm::makeArrayRef(attrs.begin(), attrs.end()));
81     }
82     case AttrValue::VALUE_NOT_SET:
83       return builder->getUnitAttr();
84     // kPlaceholder is not implemented.
85     case AttrValue::kPlaceholder:
86       return mlir::TF::PlaceholderAttr::get(builder->getContext(),
87                                             value.placeholder());
88     default:
89       return tensorflow::errors::Unimplemented(
90           absl::StrCat("Attribute ", value.DebugString()));
91   }
92 }
93 
ConvertAttributeValue(const AttrValue & value,mlir::Builder * builder)94 StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
95                                                 mlir::Builder* builder) {
96   switch (value.value_case()) {
97     case AttrValue::kFunc: {
98       // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
99       // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
100       // will not use this representation.
101       mlir::NamedAttrList attrs;
102       for (const auto& func_attr : value.func().attr()) {
103         TF_ASSIGN_OR_RETURN(auto attr,
104                             ConvertAttributeValue(func_attr.second, builder));
105         attrs.push_back(builder->getNamedAttr(func_attr.first, attr));
106       }
107       auto func_attrs = builder->getDictionaryAttr(attrs);
108       return mlir::TF::FuncAttr::get(builder->getContext(), value.func().name(),
109                                      func_attrs);
110     }
111     default:
112       return ConvertNonFuncAttributeValue(value, builder);
113   }
114 }
115 
116 }  // namespace tensorflow
117