• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/utils/string_container_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
41 // Sets type list attribute with the given `name` to the given `types`. If the
42 // attribute already exists with a different value, returns an error.
43 template <typename ContainerT,
44           typename = typename std::enable_if<
45               std::is_same<mlir::Type, decltype(*std::declval<ContainerT>()
46                                                      .begin())>::value>::type>
SetTypeAttribute(absl::string_view name,ContainerT types,AttrValueMap * values)47 Status SetTypeAttribute(absl::string_view name, ContainerT types,
48                         AttrValueMap* values) {
49   AttrValue value;
50   auto& type_list = *value.mutable_list();
51   for (auto type : types) {
52     DataType dtype;
53     TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
54     type_list.add_type(dtype);
55   }
56 
57   auto result = values->insert({string(name), value});
58   assert(result.second && "cannot have multiple attributes with the same name");
59   (void)result;
60 
61   return Status::OK();
62 }
63 
64 // Sets shape list attribute with the given `name` to the given `shapes`. If the
65 // attribute already exists with a different value, returns an error.
66 template <typename ContainerT,
67           typename = typename std::enable_if<std::is_same<
68               llvm::Optional<llvm::ArrayRef<int64_t>>,
69               decltype(*std::declval<ContainerT>().begin())>::value>::type>
SetShapeAttribute(absl::string_view name,ContainerT shapes,AttrValueMap * values)70 Status SetShapeAttribute(absl::string_view name, ContainerT shapes,
71                          AttrValueMap* values) {
72   AttrValue value;
73   auto& shape_list = *value.mutable_list();
74   for (const llvm::Optional<llvm::ArrayRef<int64_t>>& shape : shapes) {
75     TensorShapeProto& tshape = *shape_list.add_shape();
76     if (shape.hasValue()) {
77       for (int64_t dim : *shape) tshape.add_dim()->set_size(dim);
78     } else {
79       tshape.set_unknown_rank(true);
80     }
81   }
82 
83   auto result = values->insert({string(name), value});
84   assert(result.second && "cannot have multiple attributes with the same name");
85   (void)result;
86 
87   return Status::OK();
88 }
89 
90 // Collects all the unregistered attributes for an TF dialect operation.
91 // Attributes "name" and "device" are not included because they are not part
92 // of an TF op attributes.
GetUnregisteredAttrs(mlir::Operation * inst,const tensorflow::OpRegistrationData * op_reg_data,absl::flat_hash_set<absl::string_view> * attrs_to_ignore)93 Status GetUnregisteredAttrs(
94     mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data,
95     absl::flat_hash_set<absl::string_view>* attrs_to_ignore) {
96   if (!op_reg_data) {
97     // This is likely a function call node, so we should continue.
98     return Status::OK();
99   }
100 
101   // Collect all the registered attributes.
102   llvm::DenseSet<llvm::StringRef> registered_attrs;
103   registered_attrs.insert("name");
104   registered_attrs.insert("device");
105   for (const auto& attr_def : op_reg_data->op_def.attr()) {
106     registered_attrs.insert(attr_def.name());
107   }
108   // Attributes are not in the registered attributes set will be ignored.
109   for (auto& attr : inst->getAttrs()) {
110     auto attr_name = attr.first.c_str();
111     if (registered_attrs.find(attr_name) == registered_attrs.end()) {
112       attrs_to_ignore->insert(attr_name);
113     }
114   }
115   return Status::OK();
116 }
117 
118 // Collects all attribute names to ignore in an MLIR operation when exporting to
119 // a TensorFlow NodeDef.
GetAttributesToIgnore(mlir::Operation * inst,mlir::DictionaryAttr derived_attrs,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs)120 StatusOr<absl::flat_hash_set<absl::string_view>> GetAttributesToIgnore(
121     mlir::Operation* inst, mlir::DictionaryAttr derived_attrs,
122     const tensorflow::OpRegistrationData* op_reg_data,
123     bool ignore_unregistered_attrs) {
124   // The elements are owned by the MLIRContext.
125   absl::flat_hash_set<absl::string_view> attrs_to_ignore;
126 
127   // We ignore attributes attached to the operation when there is already a
128   // derived attribute defined in ODS.
129   if (derived_attrs) {
130     for (auto derived_attr : derived_attrs) {
131       attrs_to_ignore.insert(
132           mlir::StringRefToView(derived_attr.first.strref()));
133     }
134   }
135 
136   if (ignore_unregistered_attrs) {
137     TF_RETURN_IF_ERROR(
138         GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore));
139   }
140 
141   if (inst->hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
142     // TODO(b/146937733): Don't use <void> here.
143     llvm::StringRef attr_name = mlir::OpTrait::AttrSizedOperandSegments<
144         void>::getOperandSegmentSizeAttr();
145     attrs_to_ignore.insert(attr_name.data());
146   }
147 
148   if (inst->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
149     // TODO(b/146937733): Don't use <void> here.
150     llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
151         void>::getResultSegmentSizeAttr();
152     attrs_to_ignore.insert(attr_name.data());
153   }
154 
155   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst))
156     attrs_to_ignore.insert("is_stateless");
157 
158   if (llvm::isa<mlir::TF::WhileOp>(inst))
159     attrs_to_ignore.insert("shape_invariant");
160 
161   return attrs_to_ignore;
162 }
163 
164 // Populates all derived attributes of a MLIR operation in a proto
165 // map<string, AttrValue>.
PopulateDerivedAttributes(mlir::Operation * inst,llvm::StringRef name,mlir::DictionaryAttr derived_attrs,bool ignore_unregistered_attrs,AttrValueMap * attributes)166 Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name,
167                                  mlir::DictionaryAttr derived_attrs,
168                                  bool ignore_unregistered_attrs,
169                                  AttrValueMap* attributes) {
170   if (derived_attrs) {
171     TF_RETURN_WITH_CONTEXT_IF_ERROR(
172         ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{},
173                           /*remove_ref_type=*/true, attributes),
174         "while converting derived attributes for node: ",
175         mlir::StringRefToView(name));
176   }
177 
178   // Here we only add the shapes for the leading values with ShapedType,
179   // assuming values with non-ShapedType are put at the end of the result.
180   if (!ignore_unregistered_attrs && inst->getNumResults() > 0) {
181     auto values = inst->getResults();
182     auto begin = values.begin();
183     auto end = values.begin();
184     while (end != values.end() && (*end).getType().isa<mlir::ShapedType>())
185       end++;
186     if (begin != end) {
187       mlir::TF::ResultShapeRange output_shapes = {
188           mlir::TF::ResultShapeIterator(begin),
189           mlir::TF::ResultShapeIterator(end)};
190       TF_RETURN_IF_ERROR(
191           SetShapeAttribute("_output_shapes", output_shapes, attributes));
192     }
193   }
194 
195   return Status::OK();
196 }
197 
198 }  // namespace
199 
GetAttrValuesFromOperation(mlir::Operation * inst,llvm::StringRef name,const tensorflow::OpRegistrationData * op_reg_data,bool ignore_unregistered_attrs,AttrValueMap * attributes)200 Status GetAttrValuesFromOperation(
201     mlir::Operation* inst, llvm::StringRef name,
202     const tensorflow::OpRegistrationData* op_reg_data,
203     bool ignore_unregistered_attrs, AttrValueMap* attributes) {
204   mlir::DictionaryAttr derived_attrs = nullptr;
205   if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(inst))
206     derived_attrs = interface.materializeDerivedAttributes();
207   TF_ASSIGN_OR_RETURN(auto attrs_to_ignore,
208                       GetAttributesToIgnore(inst, derived_attrs, op_reg_data,
209                                             ignore_unregistered_attrs));
210   TF_RETURN_WITH_CONTEXT_IF_ERROR(
211       ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
212                         /*remove_ref_type=*/false, attributes),
213       "while converting attributes for node: ", mlir::StringRefToView(name));
214   TF_RETURN_IF_ERROR(PopulateDerivedAttributes(
215       inst, name, derived_attrs, ignore_unregistered_attrs, attributes));
216   return Status::OK();
217 }
218 
ConvertTFDialectOpToNodeDef(mlir::Operation * inst,llvm::StringRef name,bool ignore_unregistered_attrs)219 StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
220     mlir::Operation* inst, llvm::StringRef name,
221     bool ignore_unregistered_attrs) {
222   TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name));
223   TF_ASSIGN_OR_RETURN(auto op_name,
224                       GetTensorFlowOpName(inst->getName().getStringRef()));
225   const tensorflow::OpRegistrationData* op_reg_data =
226       tensorflow::OpRegistry::Global()->LookUp(op_name.str());
227   TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data,
228                                                 ignore_unregistered_attrs,
229                                                 node_def->mutable_attr()));
230   return node_def;
231 }
232 
233 }  // namespace tensorflow
234