1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
17
18 #include <memory>
19
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/string_view.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
30 #include "tensorflow/compiler/mlir/utils/string_container_utils.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status.h"
36
37 namespace tensorflow {
38
39 namespace {
40
41 // Sets type list attribute with the given `name` to the given `types`. If the
42 // attribute already exists with a different value, returns an error.
43 template <typename ContainerT,
44 typename = typename std::enable_if<
45 std::is_same<mlir::Type, decltype(*std::declval<ContainerT>()
46 .begin())>::value>::type>
SetTypeAttribute(absl::string_view name,ContainerT types,AttrValueMap * values)47 Status SetTypeAttribute(absl::string_view name, ContainerT types,
48 AttrValueMap* values) {
49 AttrValue value;
50 auto& type_list = *value.mutable_list();
51 for (auto type : types) {
52 DataType dtype;
53 TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
54 type_list.add_type(dtype);
55 }
56
57 auto result = values->insert({string(name), value});
58 assert(result.second && "cannot have multiple attributes with the same name");
59 (void)result;
60
61 return Status::OK();
62 }
63
64 // Sets shape list attribute with the given `name` to the given `shapes`. If the
65 // attribute already exists with a different value, returns an error.
66 template <typename ContainerT,
67 typename = typename std::enable_if<std::is_same<
68 llvm::Optional<llvm::ArrayRef<int64_t>>,
69 decltype(*std::declval<ContainerT>().begin())>::value>::type>
SetShapeAttribute(absl::string_view name,ContainerT shapes,AttrValueMap * values)70 Status SetShapeAttribute(absl::string_view name, ContainerT shapes,
71 AttrValueMap* values) {
72 AttrValue value;
73 auto& shape_list = *value.mutable_list();
74 for (const llvm::Optional<llvm::ArrayRef<int64_t>>& shape : shapes) {
75 TensorShapeProto& tshape = *shape_list.add_shape();
76 if (shape.hasValue()) {
77 for (int64_t dim : *shape) tshape.add_dim()->set_size(dim);
78 } else {
79 tshape.set_unknown_rank(true);
80 }
81 }
82
83 auto result = values->insert({string(name), value});
84 assert(result.second && "cannot have multiple attributes with the same name");
85 (void)result;
86
87 return Status::OK();
88 }
89
90 // Collects all the unregistered attributes for an TF dialect operation.
91 // Attributes "name" and "device" are not included because they are not part
92 // of an TF op attributes.
GetUnregisteredAttrs(mlir::Operation * inst,const tensorflow::OpRegistrationData * op_reg_data,absl::flat_hash_set<absl::string_view> * attrs_to_ignore)93 Status GetUnregisteredAttrs(
94 mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data,
95 absl::flat_hash_set<absl::string_view>* attrs_to_ignore) {
96 if (!op_reg_data) {
97 // This is likely a function call node, so we should continue.
98 return Status::OK();
99 }
100
101 // Collect all the registered attributes.
102 llvm::DenseSet<llvm::StringRef> registered_attrs;
103 registered_attrs.insert("name");
104 registered_attrs.insert("device");
105 for (const auto& attr_def : op_reg_data->op_def.attr()) {
106 registered_attrs.insert(attr_def.name());
107 }
108 // Attributes are not in the registered attributes set will be ignored.
109 for (auto& attr : inst->getAttrs()) {
110 auto attr_name = attr.first.c_str();
111 if (registered_attrs.find(attr_name) == registered_attrs.end()) {
112 attrs_to_ignore->insert(attr_name);
113 }
114 }
115 return Status::OK();
116 }
117
118 // Collects all attribute names to ignore in an MLIR operation when exporting to
119 // a TensorFlow NodeDef.
GetAttributesToIgnore(mlir::Operation * inst,mlir::DictionaryAttr derived_attrs,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs)120 StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
121 mlir::Operation* inst, mlir::DictionaryAttr derived_attrs,
122 const tensorflow::OpRegistrationData* op_reg_data,
123 bool ignore_unregistered_attrs) {
124 // The elements are owned by the MLIRContext.
125 absl::flat_hash_set<absl::string_view> attrs_to_ignore;
126
127 // We ignore attributes attached to the operation when there is already a
128 // derived attribute defined in ODS.
129 if (derived_attrs) {
130 for (auto derived_attr : derived_attrs) {
131 attrs_to_ignore.insert(
132 mlir::StringRefToView(derived_attr.first.strref()));
133 }
134 }
135
136 if (ignore_unregistered_attrs) {
137 TF_RETURN_IF_ERROR(
138 GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore));
139 }
140
141 if (inst->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
142 // TODO(b/146937733): Don't use <void> here.
143 llvm::StringRef attr_name = mlir::OpTrait::AttrSizedOperandSegments<
144 void>::getOperandSegmentSizeAttr();
145 attrs_to_ignore.insert(attr_name.data());
146 }
147
148 if (inst->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
149 // TODO(b/146937733): Don't use <void> here.
150 llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
151 void>::getResultSegmentSizeAttr();
152 attrs_to_ignore.insert(attr_name.data());
153 }
154
155 if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
156 attrs_to_ignore.insert("is_stateless");
157
158 if (llvm::isa<mlir::TF::WhileOp>(inst))
159 attrs_to_ignore.insert("shape_invariant");
160
161 return attrs_to_ignore;
162 }
163
164 // Populates all derived attributes of a MLIR operation in a proto
165 // map<string, AttrValue>.
PopulateDerivedAttributes(mlir::Operation * inst,llvm::StringRef name,mlir::DictionaryAttr derived_attrs,bool ignore_unregistered_attrs,AttrValueMap * attributes)166 Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name,
167 mlir::DictionaryAttr derived_attrs,
168 bool ignore_unregistered_attrs,
169 AttrValueMap* attributes) {
170 if (derived_attrs) {
171 TF_RETURN_WITH_CONTEXT_IF_ERROR(
172 ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{},
173 /*remove_ref_type=*/true, attributes),
174 "while converting derived attributes for node: ",
175 mlir::StringRefToView(name));
176 }
177
178 // Here we only add the shapes for the leading values with ShapedType,
179 // assuming values with non-ShapedType are put at the end of the result.
180 if (!ignore_unregistered_attrs && inst->getNumResults() > 0) {
181 auto values = inst->getResults();
182 auto begin = values.begin();
183 auto end = values.begin();
184 while (end != values.end() && (*end).getType().isa<mlir::ShapedType>())
185 end++;
186 if (begin != end) {
187 mlir::TF::ResultShapeRange output_shapes = {
188 mlir::TF::ResultShapeIterator(begin),
189 mlir::TF::ResultShapeIterator(end)};
190 TF_RETURN_IF_ERROR(
191 SetShapeAttribute("_output_shapes", output_shapes, attributes));
192 }
193 }
194
195 return Status::OK();
196 }
197
198 } // namespace
199
GetAttrValuesFromOperation(mlir::Operation * inst,llvm::StringRef name,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs,AttrValueMap * attributes)200 Status GetAttrValuesFromOperation(
201 mlir::Operation* inst, llvm::StringRef name,
202 const tensorflow::OpRegistrationData* op_reg_data,
203 bool ignore_unregistered_attrs, AttrValueMap* attributes) {
204 mlir::DictionaryAttr derived_attrs = nullptr;
205 if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(inst))
206 derived_attrs = interface.materializeDerivedAttributes();
207 TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
208 GetAttributesToIgnore(inst, derived_attrs, op_reg_data,
209 ignore_unregistered_attrs));
210 TF_RETURN_WITH_CONTEXT_IF_ERROR(
211 ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
212 /*remove_ref_type=*/false, attributes),
213 "while converting attributes for node: ", mlir::StringRefToView(name));
214 TF_RETURN_IF_ERROR(PopulateDerivedAttributes(
215 inst, name, derived_attrs, ignore_unregistered_attrs, attributes));
216 return Status::OK();
217 }
218
ConvertTFDialectOpToNodeDef(mlir::Operation * inst,llvm::StringRef name,bool ignore_unregistered_attrs)219 StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
220 mlir::Operation* inst, llvm::StringRef name,
221 bool ignore_unregistered_attrs) {
222 TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name));
223 TF_ASSIGN_OR_RETURN(auto op_name,
224 GetTensorFlowOpName(inst->getName().getStringRef()));
225 const tensorflow::OpRegistrationData* op_reg_data =
226 tensorflow::OpRegistry::Global()->LookUp(op_name.str());
227 TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data,
228 ignore_unregistered_attrs,
229 node_def->mutable_attr()));
230 return node_def;
231 }
232
233 } // namespace tensorflow
234