1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
17
18 #include <vector>
19
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
27 #include "mlir/IR/Attributes.h" // TF:llvm-project
28 #include "mlir/IR/Function.h" // TF:llvm-project
29 #include "mlir/IR/Identifier.h" // TF:llvm-project
30 #include "mlir/IR/Location.h" // TF:llvm-project
31 #include "mlir/IR/Module.h" // TF:llvm-project
32 #include "mlir/IR/Operation.h" // TF:llvm-project
33 #include "mlir/IR/OperationSupport.h" // TF:llvm-project
34 #include "mlir/IR/StandardTypes.h" // TF:llvm-project
35 #include "mlir/IR/TypeUtilities.h" // TF:llvm-project
36 #include "mlir/Support/DebugStringHelper.h" // TF:llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/core/framework/attr_value.pb.h"
45 #include "tensorflow/core/framework/graph.pb.h"
46 #include "tensorflow/core/framework/graph_to_functiondef.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/op.h"
50 #include "tensorflow/core/framework/tensor.pb.h"
51 #include "tensorflow/core/framework/tensor_shape.pb.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/graph/algorithm.h"
54 #include "tensorflow/core/graph/graph.h"
55 #include "tensorflow/core/graph/graph_constructor.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/platform/protobuf.h"
58
59 namespace tensorflow {
60 namespace {
61 // Converts a location to the debug information for the node def.
ConvertLocation(mlir::Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)62 Status ConvertLocation(mlir::Location inst_loc,
63 NodeDef::ExperimentalDebugInfo* debug_info) {
64 if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
65 if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
66 debug_info->add_original_node_names(name_loc.getName().c_str());
67 }
68 } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
69 auto locations = fused.getLocations();
70 if (locations.size() <= 1)
71 return errors::InvalidArgument("expected experimental debuf info.");
72 // skip the first one, which is the name of the node_def.
73 for (int i = 0; i < locations.size() - 1; ++i) {
74 TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
75 }
76 }
77 return Status::OK();
78 }
79
ConvertAttribute(const mlir::BoolAttr & attr,AttrValue * value)80 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
81 value->set_b(attr.getValue());
82 return Status::OK();
83 }
84
ConvertAttribute(const mlir::IntegerAttr & attr,AttrValue * value)85 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
86 value->set_i(attr.getInt());
87 return Status::OK();
88 }
89
ConvertAttribute(const mlir::FloatAttr & attr,AttrValue * value)90 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
91 value->set_f(attr.getValueAsDouble());
92 return Status::OK();
93 }
94
ConvertAttribute(const mlir::ElementsAttr & attr,AttrValue * value)95 Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
96 return ConvertToTensorProto(attr, value->mutable_tensor());
97 }
98
ConvertAttribute(const mlir::StringAttr & attr,AttrValue * value)99 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
100 absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
101 switch (mangling_util::GetMangledKind(attr_value)) {
102 case mangling_util::MangledKind::kUnknown: {
103 value->set_s(std::string(attr_value));
104 return Status::OK();
105 }
106 case mangling_util::MangledKind::kDataType: {
107 DataType dtype;
108 TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
109 value->set_type(dtype);
110 return Status::OK();
111 }
112 case mangling_util::MangledKind::kTensorShape:
113 TF_RETURN_IF_ERROR(
114 mangling_util::DemangleShape(attr_value, value->mutable_shape()));
115 return Status::OK();
116 default:
117 return errors::Unimplemented("Mangled string couldn't be handled!");
118 }
119 return Status::OK();
120 }
121
ConvertAttribute(mlir::Type type,AttrValue * value)122 Status ConvertAttribute(mlir::Type type, AttrValue* value) {
123 DataType dtype;
124 TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
125 value->set_type(dtype);
126 return Status::OK();
127 }
128
ConvertAttribute(const mlir::TypeAttr & type,AttrValue * value)129 Status ConvertAttribute(const mlir::TypeAttr& type, AttrValue* value) {
130 return ConvertAttribute(type.getValue(), value);
131 }
132
// Converts an MLIR unit attribute: a unit attribute carries no payload, so
// the AttrValue is cleared to encode "present with no value".
Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return Status::OK();
}
137
ConvertAttribute(const mlir::FlatSymbolRefAttr & attr,AttrValue * value)138 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
139 value->mutable_func()->set_name(std::string(attr.getValue()));
140 return Status::OK();
141 }
142
ConvertAttribute(const mlir::ArrayAttr & attr,AttrValue * value)143 Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
144 auto* list = value->mutable_list();
145 for (mlir::Attribute a : attr.getValue()) {
146 if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
147 list->add_b(attr.getValue());
148 } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
149 list->add_i(attr.getInt());
150 } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
151 list->add_f(attr.getValueAsDouble());
152 } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
153 AttrValue nested_value;
154 TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
155 switch (nested_value.value_case()) {
156 case AttrValue::kS:
157 list->add_s(nested_value.s());
158 break;
159 case AttrValue::kType:
160 list->add_type(nested_value.type());
161 break;
162 case AttrValue::kShape:
163 *list->add_shape() = nested_value.shape();
164 break;
165 default:
166 return errors::Unimplemented("Unhandled nested attribute!");
167 }
168 } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
169 TensorProto tensor;
170 TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
171 *list->add_tensor() = tensor;
172 } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
173 AttrValue attr_val;
174 TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
175 *list->add_func() = attr_val.func();
176 } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
177 AttrValue attr_val;
178 // For type attributes, we only propagate the element type.
179 mlir::Type elt_type = attr.getValue();
180 if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
181 elt_type = shaped_type.getElementType();
182 }
183 TF_RETURN_IF_ERROR(ConvertAttribute(elt_type, &attr_val));
184 list->add_type(attr_val.type());
185 } else {
186 return errors::Unimplemented("Unhandled attribute!");
187 }
188 }
189 return Status::OK();
190 }
191
192 // Updates NodeDef constructed out of an MLIR If op to map it to either
193 // TensorFlow StatelessIf or If op depending on the additional attribute.
UpdateCompositeIfOp(NodeDef * node_def)194 void UpdateCompositeIfOp(NodeDef* node_def) {
195 auto it = node_def->mutable_attr()->find("is_stateless");
196 if (it != node_def->attr().end()) {
197 if (it->second.b()) {
198 *node_def->mutable_op() = "StatelessIf";
199 }
200 node_def->mutable_attr()->erase(it);
201 }
202 }
203
204 // Updates NodeDef constructed out of an MLIR While op to map it to either
205 // TensorFlow StatelessWhile or While op depending on the additional attribute.
UpdateCompositeWhileOp(NodeDef * node_def)206 void UpdateCompositeWhileOp(NodeDef* node_def) {
207 auto it = node_def->mutable_attr()->find("is_stateless");
208 if (it != node_def->attr().end()) {
209 if (it->second.b()) {
210 *node_def->mutable_op() = "StatelessWhile";
211 }
212 node_def->mutable_attr()->erase(it);
213 }
214 }
215
216 // Returns true if the executor/control dialect op should map to Ref node in
217 // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand
218 // type. For executor dialect NextIteration it uses the 2nd operand type. For
219 // all others (Enter/Exit/Merge/Switch), if the output type is ref, they
220 // correspond to the Ref equivalent op in TF Graph.
IsRefTypeControlOp(mlir::Operation * op)221 static bool IsRefTypeControlOp(mlir::Operation* op) {
222 if (auto next_iter_sink =
223 llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
224 return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
225 .isa<mlir::TF::TensorFlowRefType>();
226
227 auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
228 if (!op_name_or_status.ok()) return false;
229
230 auto op_name = op_name_or_status.ConsumeValueOrDie();
231 if (op_name.equals("NextIteration"))
232 return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
233 .isa<mlir::TF::TensorFlowRefType>();
234
235 if (op_name.equals("Enter") || op_name.equals("Exit") ||
236 op_name.equals("Switch") || op_name.equals("Merge")) {
237 return getElementTypeOrSelf(op->getResult(0).getType())
238 .isa<mlir::TF::TensorFlowRefType>();
239 }
240 return false;
241 }
242
243 } // anonymous namespace
244
GetTensorFlowOpName(llvm::StringRef op_name)245 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
246 // When being converted to MLIR, some prefixes and suffixes are added to the
247 // operation types, and we have to remove them when converting the
248 // operations back to a graph:
249 // - "_tf.", "tf." or "tf_executor." : every operation type has this prefix.
250 // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
251 // don't need to consider ".source"/".Source" because the nodes with this
252 // suffix are skipped by the caller and will not be added to the graph.
253 if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
254 !op_name.consume_front("tf_executor.")) {
255 return errors::FailedPrecondition("op node '", op_name.str(),
256 "' was not a TF op!");
257 }
258 // Control dialect NextIteration sink ends with ".sink" and Executor dialect
259 // NextIteration sink ends with ".Sink".
260 if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
261 return op_name;
262 }
263
// Converts the MLIR operation `inst` into a NodeDef named `name`: derives the
// TF op name (handling legacy function calls and Ref-type control ops),
// pre-sizes the input list, copies the device attribute, converts the
// remaining attributes (minus `attrs_to_ignore`), attaches debug info from
// the location, and maps composite If/While ops to their stateless variants.
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(
        inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
    // Remove the attribute from the instruction as it is already converted to
    // op_name. Note this mutates `inst` before attribute conversion below so
    // that "f" is not emitted twice.
    auto attr_id = mlir::Identifier::get("f", inst->getContext());
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in TensorFlow Graph have their respective "Ref" ops
    // as well. For example there is Enter and RefEnter op. RefEnter forwards
    // the input ref buffer to output. However both Enter and RefEnter are
    // mapped to tf_executor::EnterOp during import and then to _tf.Enter op in
    // control dialect. Check if it is a Ref op to correctly map to the
    // TensorFlow Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Add inputs to the NodeDef based on the number of operands. This is required
  // as later when edges are added to the Node using Graph::AddEdge the
  // associated NodeDef is not updated.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  // The "device" attribute maps to the NodeDef device field, not to an
  // AttrValue (ConvertAttributes skips it).
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node attributes.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
                        node_def->mutable_attr()),
      "while converting attributes for node: ", name.str());

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), node_def->mutable_experimental_debug_info()));

  // Rewrite composite ops once the op name and attributes are final, since
  // the rewrite consumes the "is_stateless" attribute.
  if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
  if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get());

  return node_def;
}
321
// Converts the MLIR attributes `attrs` into the NodeDef attribute map
// `values`. "name"/"device" and entries in `attrs_to_ignore` are skipped;
// mangled names are demangled; SymbolRef attributes become function-call
// attrs; and dotted names ("func.attr") are folded into the matching
// function attr's nested attribute map.
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    AttrValueMap* values) {
  // Function-call attributes are collected separately and merged into
  // `values` at the end, so that later "func.attr" entries can be attached
  // to them first.
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.first.str();
    auto attr = named_attr.second;
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name, device spec of a TF op or function are not stored as
      // AttrValue inside NodeDef, but we model them using attribute inside
      // MLIR. So we need to ignore them when going back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions requires dialect prefix. We need to
      // remove TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    switch (attr.getKind()) {
      case mlir::StandardAttributes::SymbolRef: {
        auto func_attr = attr.cast<mlir::FlatSymbolRefAttr>();
        value.mutable_func()->set_name(std::string(func_attr.getValue()));
        // Stash rather than emit directly; merged into `values` below.
        func_call_attrs[string(name)] = value;
        continue;
      }
      case mlir::StandardAttributes::Bool:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::BoolAttr>(), &value));
        break;
      case mlir::StandardAttributes::Integer:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::IntegerAttr>(), &value));
        break;
      case mlir::StandardAttributes::Float:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::FloatAttr>(), &value));
        break;
      case mlir::StandardAttributes::String:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::StringAttr>(), &value));
        break;
      case mlir::StandardAttributes::Array:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::ArrayAttr>(), &value));
        break;
      case mlir::StandardAttributes::DenseElements:
      case mlir::StandardAttributes::OpaqueElements:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::ElementsAttr>(), &value));
        break;
      case mlir::StandardAttributes::Type:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::TypeAttr>(), &value));
        break;
      case mlir::StandardAttributes::Unit:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::UnitAttr>(), &value));
        break;
      // AffineMap kind is not implemented.
      case mlir::StandardAttributes::AffineMap:
        return errors::Unimplemented("AffineMap attribute (needed for '",
                                     name_strref, "') unimplemented");
      default:
        return errors::Unimplemented("Unhandled attribute kind for attribute '",
                                     name_strref, '\'');
    }
    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      // NOTE(review): assumes a hit in func_call_attrs implies the name had
      // the two-token "func.attr" form; an undotted attribute whose name
      // collides with a function attr would index name_tokens[1] out of
      // bounds — confirm this cannot occur upstream.
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  // Emit the function-call attrs, now carrying any nested attributes.
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return Status::OK();
}
410
411 // Sets type attribute with the given name. If the attribute already exists with
412 // a different value, returns an error.
SetTypeAttribute(absl::string_view name,mlir::Type type,AttrValueMap * values)413 Status SetTypeAttribute(absl::string_view name, mlir::Type type,
414 AttrValueMap* values) {
415 DataType dtype;
416 TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
417 if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
418 AttrValue value;
419 value.set_type(dtype);
420
421 auto result = values->insert({string(name), value});
422 if (!result.second) {
423 DataType actual_dtype = result.first->second.type();
424 if (actual_dtype != dtype) {
425 return errors::InvalidArgument("Expected ", DataType_Name(dtype), " '",
426 name, "' attribute but found ",
427 DataType_Name(actual_dtype));
428 }
429 }
430 return Status::OK();
431 }
432
SetShapeAttribute(absl::string_view name,mlir::ShapedType shaped_type,AttrValueMap * values)433 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
434 AttrValueMap* values) {
435 tensorflow::TensorShapeProto tshape;
436 AttrValue value;
437 if (shaped_type.hasRank()) {
438 for (auto dim : shaped_type.getShape()) tshape.add_dim()->set_size(dim);
439 } else {
440 tshape.set_unknown_rank(true);
441 }
442 *value.mutable_shape() = tshape;
443
444 auto result = values->insert({string(name), value});
445 if (!result.second) {
446 // This should be extremely rare as it means we are adding the same
447 // attribute multiple times/have some redundancy in representing this
448 // attribute.
449 TensorShapeProto actual_shape = result.first->second.shape();
450 // Just check via string output as we shouldn't get here and if we do they
451 // should be trivially the same, else fail.
452 if (actual_shape.ShortDebugString() != tshape.ShortDebugString()) {
453 return errors::InvalidArgument("Expected ", tshape.ShortDebugString(),
454 " '", name, "' attribute but found ",
455 actual_shape.ShortDebugString());
456 }
457 }
458 return Status::OK();
459 }
460
SetSizeAttribute(absl::string_view name,size_t size,AttrValueMap * values)461 Status SetSizeAttribute(absl::string_view name, size_t size,
462 AttrValueMap* values) {
463 AttrValue value;
464 value.set_i(size);
465
466 auto result = values->insert({string(name), value});
467 if (!result.second) {
468 // This should be extremely rare as it means we are adding the same
469 // attribute multiple times/have some redundancy in representing this
470 // attribute.
471 int64 actual_size = result.first->second.i();
472 // Just check via string output as we shouldn't get here and if we do they
473 // should be trivially the same, else fail.
474 if (actual_size != size)
475 return errors::InvalidArgument("Expected '", name, "' attribute to be ",
476 size, " but found ", actual_size);
477 }
478 return Status::OK();
479 }
480
IsLegacyCallInstruction(mlir::Operation * inst)481 bool IsLegacyCallInstruction(mlir::Operation* inst) {
482 return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst) ||
483 inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
484 }
485
486 } // namespace tensorflow
487