/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace {
// Static set of recognized TensorFlow op prefixes.
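// The set can be extended at runtime via AddTensorFlowOpPrefix() below, e.g.
// (the dialect prefix here is purely illustrative):
//   tensorflow::AddTensorFlowOpPrefix("tf_custom.");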
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* global_op_prefixes = [] {
    std::set<std::string>* result = new std::set<std::string>;
    result->insert("tf.");
    result->insert("tf_executor.");
    return result;
  }();
  return global_op_prefixes;
}

// Converts a location to the debug information for the node def.
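// A CallSiteLoc whose callee is a NameLoc records one original node name. A
// FusedLoc, as constructed during import, holds the locations of the original
// nodes followed by the node's own name location, which is skipped.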
Status ConvertLocation(mlir::Location inst_loc,
                       NodeDef::ExperimentalDebugInfo* debug_info) {
  if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
    if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
      debug_info->add_original_node_names(name_loc.getName().c_str());
    }
  } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
    auto locations = fused.getLocations();
    if (locations.size() <= 1)
      return errors::InvalidArgument("expected experimental debug info.");
    // Skip the last location, which is the name of the node_def.
    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
      TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
    }
  }
  return Status::OK();
}

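// Overloads converting each supported MLIR attribute kind to the matching
// field of an AttrValue.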
Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
  value->set_b(attr.getValue());
  return Status::OK();
}

Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
  value->set_i(attr.getInt());
  return Status::OK();
}

Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
  value->set_f(attr.getValueAsDouble());
  return Status::OK();
}

Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
  return ConvertToTensorProto(attr, value->mutable_tensor());
}

Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
  auto* shape = value->mutable_shape();
  if (attr.hasRank()) {
    for (auto dim_size : attr.getShape()) {
      auto* dim = shape->add_dim();
      dim->set_size(dim_size);
    }
  } else {
    shape->set_unknown_rank(true);
  }
  return Status::OK();
}

Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
  value->mutable_func()->set_name(attr.getValue().str());
  return Status::OK();
}

Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  TF_RETURN_IF_ERROR(
      ConvertAttribute(attr.GetName().cast<mlir::FlatSymbolRefAttr>(), value));
  TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(),
                                       /*attrs_to_ignore=*/{}, remove_ref_type,
                                       value->mutable_func()->mutable_attr()));
  return Status::OK();
}

Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
  absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
  switch (mangling_util::GetMangledKind(attr_value)) {
    case mangling_util::MangledKind::kUnknown: {
      value->set_s(std::string(attr_value));
      return Status::OK();
    }
    case mangling_util::MangledKind::kDataType: {
      DataType dtype;
      TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
      value->set_type(dtype);
      return Status::OK();
    }
    case mangling_util::MangledKind::kTensorShape:
      TF_RETURN_IF_ERROR(
          mangling_util::DemangleShape(attr_value, value->mutable_shape()));
      return Status::OK();
    default:
      return errors::Unimplemented("Mangled string couldn't be handled!");
  }
  return Status::OK();
}

Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
                        AttrValue* value) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
  // Only strip the ref-ness when the caller requests it.
  if (remove_ref_type && tensorflow::IsRefType(dtype))
    dtype = tensorflow::RemoveRefType(dtype);
  value->set_type(dtype);
  return Status::OK();
}

Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
                        AttrValue* value) {
  return ConvertAttribute(type.getValue(), remove_ref_type, value);
}

Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return Status::OK();
}

Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return Status::OK();
}

// Returns true if the executor/control dialect op should map to a Ref node in
// the TensorFlow graph. For control dialect NextIteration it uses the 1st
// operand type. For executor dialect NextIteration it uses the 2nd operand
// type. For all others (Enter/Exit/Merge/Switch), if the output type is ref,
// they correspond to the Ref equivalent op in the TF graph.
static bool IsRefTypeControlOp(mlir::Operation* op) {
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
    return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
        .isa<mlir::TF::TensorFlowRefType>();

  auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
  if (!op_name_or_status.ok()) return false;

  auto op_name = op_name_or_status.ConsumeValueOrDie();
  if (op_name.equals("NextIteration"))
    return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();

  if (op_name.equals("Enter") || op_name.equals("Exit") ||
      op_name.equals("Switch") || op_name.equals("Merge")) {
    return mlir::getElementTypeOrSelf(op->getResult(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();
  }
  return false;
}

}  // anonymous namespace

StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
  // When being converted to MLIR, some prefixes and suffixes are added to the
  // operation types, and we have to remove them when converting the
  // operations back to a graph:
  // - "tf." or "tf_executor.": every operation type has this prefix.
  // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
  //   don't need to consider ".source"/".Source" because the nodes with this
  //   suffix are skipped by the caller and will not be added to the graph.
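  // For example, "tf.MatMul" becomes "MatMul" and
  // "tf_executor.NextIteration.Sink" becomes "NextIteration".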
  auto prefixes = GlobalOpPrefixes();
  if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
        return op_name.consume_front(prefix);
      })) {
    return errors::FailedPrecondition("op node '", op_name.str(),
                                      "' was not a TF op!");
  }
  // Control dialect NextIteration sink ends with ".sink" and executor dialect
  // NextIteration sink ends with ".Sink".
  if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
  return op_name;
}

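// Converts an MLIR operation to a NodeDef with the given node name, filling
// in the op name, (empty) inputs, device, and debug info. Op attributes are
// converted separately via ConvertAttributes().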
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(
        inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::Identifier::get("f", inst->getContext());
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in the TensorFlow graph have corresponding "Ref"
    // ops as well. For example, there are both Enter and RefEnter ops;
    // RefEnter forwards its input ref buffer to its output. However, both
    // Enter and RefEnter are mapped to tf_executor::EnterOp during import, so
    // check here whether this is a Ref op in order to map it to the correct
    // TensorFlow graph op.
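    // For example, a tf_executor.Enter whose result element type is a
    // TensorFlowRefType is exported as "RefEnter" rather than "Enter".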
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Update a NodeDef constructed out of an MLIR Case/If/While op to map it to
  // either TensorFlow StatelessX or X op depending on the additional attribute.
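  // For example, a tf.If op carrying is_stateless = true is exported as
  // "StatelessIf"; with is_stateless = false (or absent) it remains "If".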
  if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
    auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
    if (stateless && stateless.getValue())
      *node_def->mutable_op() = "Stateless" + node_def->op();
  }

  // Add inputs to the NodeDef based on the number of operands. This is
  // required because Graph::AddEdge, which is used later to add the edges to
  // the Node, does not update the associated NodeDef.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), node_def->mutable_experimental_debug_info()));

  return node_def;
}

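// Converts the MLIR attributes in `attrs` to an AttrValueMap, skipping
// "name", "device", and any names in `attrs_to_ignore`. Function-call
// attributes (symbol references and TF::FuncAttr) are collected first so
// that dotted attribute names can be nested under the corresponding
// function-call attribute.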
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    bool remove_ref_type, AttrValueMap* values) {
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.first.str();
    auto attr = named_attr.second;
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name and device spec of a TF op or function are not stored as
      // AttrValues inside the NodeDef, but we model them as attributes in
      // MLIR, so they need to be ignored here when converting back to
      // AttrValues.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions require a dialect prefix, so remove
      // the TF dialect prefix before converting to an AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
      TF_RETURN_IF_ERROR(
          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
      TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (attr.isa<mlir::AffineMapAttr>()) {
      // AffineMapAttr is not implemented.
      return errors::Unimplemented("AffineMap attribute (needed for '",
                                   name_strref, "') unimplemented");
    }
    TF_RETURN_IF_ERROR(
        llvm::TypeSwitch<mlir::Attribute, Status>(attr)
            .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
                  mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
                  mlir::TF::ShapeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, &value);
            })
            .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, remove_ref_type, &value);
            })
            .Default([&](mlir::Attribute) {
              return errors::Unimplemented(
                  "Unhandled attribute kind for attribute '", name_strref,
                  '\'');
            }));

    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
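    // For example, an MLIR attribute "cond.bar" is stored as attribute "bar"
    // on the function referenced by the "cond" attribute collected in
    // func_call_attrs above (the names here are illustrative).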
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return Status::OK();
}

Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
                         AttrValueMap* values) {
  tensorflow::TensorShapeProto tshape;
  AttrValue value;
  if (shaped_type.hasRank()) {
    for (auto dim : shaped_type.getShape()) tshape.add_dim()->set_size(dim);
  } else {
    tshape.set_unknown_rank(true);
  }
  *value.mutable_shape() = tshape;

  auto result = values->insert({string(name), value});
  if (!result.second) {
    // This should be extremely rare, as it means we are adding the same
    // attribute multiple times, i.e. there is some redundancy in how this
    // attribute is represented.
    TensorShapeProto actual_shape = result.first->second.shape();
    // Just compare via the string output, as we shouldn't get here and, if we
    // do, the two shapes should be trivially identical; otherwise fail.
    if (actual_shape.ShortDebugString() != tshape.ShortDebugString()) {
      return errors::InvalidArgument("Expected ", tshape.ShortDebugString(),
                                     " '", name, "' attribute but found ",
                                     actual_shape.ShortDebugString());
    }
  }
  return Status::OK();
}

bool IsLegacyCallInstruction(mlir::Operation* inst) {
  return llvm::isa<mlir::TF::LegacyCallOp>(inst);
}

Status AddTensorFlowOpPrefix(std::string prefix) {
  GlobalOpPrefixes()->insert(prefix);
  return Status::OK();
}

}  // namespace tensorflow